1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define USE_DEPRECATED_API
17 #include "tools/optimizer/fusion/flash_attention_fusion.h"
18 #include <memory>
19 #include <utility>
20 #include <string>
21 #include "ops/op_utils.h"
22 #include "ops/array_ops.h"
23 #include "ops/nn_ops.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "mindspore/core/ops/lite_ops.h"
26 #include "ops/incre_flash_attention.h"
27 #include "ops/prompt_flash_attention.h"
28 #include "ops/fusion/pad_fusion.h"
29 #include "ops/slice.h"
30 #include "ops/auto_generate/gen_lite_ops.h"
31
32 namespace mindspore::opt {
33 namespace {
34 static int kNameIndex = 0;
35 constexpr auto kNameFlashAttentionPatternForMsSD21 = "FlashAttentionPatternForMsSD21";
36 constexpr auto kNameFlashAttentionPatternForMsSDXL = "FlashAttentionPatternForMsSDXL";
37 constexpr auto kNameFlashAttentionPatternForVideoComposer = "FlashAttentionPatternForVideoComposer";
38 constexpr auto kNameFlashAttentionPatternForSDBSH = "FlashAttentionPatternForSDBSH";
39 constexpr auto kNameFlashAttentionPatternForSDWithoutCast = "FlashAttentionPatternForSDWithoutCast";
40 constexpr auto kNameFlashAttentionPatternForSDPreMul = "FlashAttentionPatternForSDPreMul";
41 constexpr auto kNameFlashAttentionPatternForPanGu = "FlashAttentionPatternForPanGu";
42 constexpr auto kNameFlashAttentionPatternForLLAMAPatternV1 = "FlashAttentionPatternForLLAMAPatternV1";
43 constexpr auto kNameFlashAttentionPatternForLLAMAPatternV2 = "FlashAttentionPatternForLLAMAPatternV2";
44 constexpr auto kNameFlashAttentionPatternForBaiChuan = "FlashAttentionPatternForBaiChuan";
45 constexpr auto kNameFlashAttentionPatternForMsSDPseShift = "FlashAttentionPatternForMsSDPseShift";
46 constexpr auto kNameFlashAttentionPatternForSDEinsum = "FlashAttentionPatternForSDEinsum";
47 constexpr auto kNamePadNodeSuffix = "_fa_pad";
48 constexpr auto kSocVersionAscend310P = "Ascend310P3";
49 constexpr size_t high_inner_precise = 0;
50 constexpr size_t high_performance = 1;
51 constexpr size_t kNumIndex0 = 0;
52 constexpr size_t kNumIndex1 = 1;
53 constexpr size_t kNumIndex2 = 2;
54 constexpr size_t kNumIndex3 = 3;
55 constexpr size_t kNumDimSize4 = 4;
56 constexpr size_t kNumShapeSize2 = 2;
57 constexpr size_t kNumShapeSize3 = 3;
58 constexpr size_t kNumShapeSize4 = 4;
59 constexpr int64_t kNumMaxBatchLenSize = 128;
60 constexpr int64_t kNumMaxNextTokenSize = 65535;
61 constexpr int kNumMultiple32 = 32;
62 constexpr int kNumMultiple16 = 16;
63 constexpr int64_t kNumDValue = 40;
64 constexpr int64_t kNumPadSize = 8;
65 constexpr int kNumPowerTwo = 2;
66 constexpr float kNumPowerHalf = 0.5;
67
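// Returns true if the node is a Div or RealDiv; in these patterns this is typically the Q*K^T scaling step.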
68 bool IsDivNode(const BaseRef &n) {
69 if (utils::isa<AnfNodePtr>(n)) {
70 auto anf_node = utils::cast<AnfNodePtr>(n);
71 return CheckPrimitiveType(anf_node, prim::kPrimDiv) || CheckPrimitiveType(anf_node, prim::kPrimRealDiv);
72 }
73 return false;
74 }
75
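// Checks for a grouped-query attention (GQA) layout: the K input of the Q*K^T matmul and the V input of the
// second matmul must both come from Reshape(Tile(...)), i.e. the KV heads are tiled up to the Q head count.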
76 bool IsGQAPattern(const CNodePtr qk_matmul, const CNodePtr v_matmul) {
77 auto k_reshape = qk_matmul->input(kNumIndex2)->cast<CNodePtr>();
78 if (!CheckPrimitiveType(k_reshape, prim::kPrimReshape)) {
79 return false;
80 }
81 auto k_tile = k_reshape->input(kNumIndex1)->cast<CNodePtr>();
82 if (!CheckPrimitiveType(k_tile, prim::kPrimTile)) {
83 return false;
84 }
85 auto v_reshape = v_matmul->input(kNumIndex2)->cast<CNodePtr>();
86 if (!CheckPrimitiveType(v_reshape, prim::kPrimReshape)) {
87 return false;
88 }
89 auto v_tile = v_reshape->input(kNumIndex1)->cast<CNodePtr>();
90 if (!CheckPrimitiveType(v_tile, prim::kPrimTile)) {
91 return false;
92 }
93 return true;
94 }
95
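// Validates the scale and Q/K/V shapes before fusing PromptFlashAttention. For 4-D static shapes the scale is
// recomputed as 1/sqrt(D) and the batch/sequence sizes are range-checked; otherwise D is derived from the scale
// as round(1 / scale^2). Fusion is allowed only if D is a multiple of 16 or exactly 40.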
96 bool PFACheckShape(float scale_value, const std::vector<int64_t> &q_shape, const std::vector<int64_t> &k_shape,
97 const std::vector<int64_t> &v_shape, int64_t seq_threshold = 0) {
98 if (scale_value < 0) {
99 MS_LOG(WARNING) << "scale value is invalid.";
100 return false;
101 }
102 int64_t d_value = 0;
103 if (q_shape.size() == kNumShapeSize4 && k_shape.size() == kNumShapeSize4 && v_shape.size() == kNumShapeSize4) {
104 MS_LOG(INFO) << "get flash attention param for static shape.";
105 if (q_shape[kNumIndex0] >= kNumMaxBatchLenSize) {
106 MS_LOG(INFO) << "flash attention does not support batch size >= " << kNumMaxBatchLenSize;
107 return false;
108 }
109 // for static shape: get scale value
110 scale_value = 1 / (pow(q_shape[kNumIndex3], kNumPowerHalf));
111 auto q_seq_len = q_shape[kNumIndex2];
112 auto k_seq_len = k_shape[kNumIndex2];
113 auto v_seq_len = v_shape[kNumIndex2];
114 d_value = q_shape[kNumIndex3];
115 MS_LOG(INFO) << "check param in stable diffusion models, scale_value: " << scale_value
116 << ", q_seq_len: " << q_seq_len << ", k_seq_len: " << k_seq_len << ", v_seq_len: " << v_seq_len
117 << ", d_value: " << d_value;
118 if (q_seq_len < seq_threshold || k_seq_len < seq_threshold || v_seq_len < seq_threshold) {
119 MS_LOG(INFO) << "seq len is less than seq_threshold, seq_threshold is: " << seq_threshold;
120 return false;
121 }
122 } else {
123 d_value = std::lround(1 / pow(scale_value, kNumPowerTwo));
124 }
125 if (static_cast<int>(d_value) % kNumMultiple16 == 0 || static_cast<int>(d_value) == kNumDValue) {
126 return true;
127 }
128 MS_LOG(INFO) << "D value must be an integer multiple of 16 or equal to 40, d value: " << static_cast<int>(d_value);
129 return false;
130 }
131
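// Reads element `index` of the constant shape parameter (second input) of a Reshape node.
// Only int32 shape tensors with exactly 4 elements are supported; returns -1 on any check failure.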
132 int32_t GetReshapeParam(const AnfNodePtr &reshape_node, size_t index) {
133 if (!utils::isa<CNodePtr>(reshape_node)) {
134 MS_LOG(INFO) << "reshape_node is not CNode!";
135 return -1;
136 }
137 auto reshape_cnode = reshape_node->cast<CNodePtr>();
138 if (reshape_cnode->inputs().size() < kNumShapeSize3) {
139 MS_LOG(WARNING) << "reshape_cnode inputs size < 3!";
140 return -1;
141 }
142 auto reshape_input_2 = reshape_cnode->input(kNumIndex2);
143 if (!utils::isa<ParameterPtr>(reshape_input_2)) {
144 MS_LOG(INFO) << "reshape_input_2 is not ParameterPtr!";
145 return -1;
146 }
147 auto reshape_param = reshape_input_2->cast<ParameterPtr>();
148 if (reshape_param == nullptr) {
149 MS_LOG(WARNING) << "reshape_param is nullptr!";
150 return -1;
151 }
152 auto reshape_default_param = reshape_param->default_param();
153 if (reshape_default_param == nullptr) {
154 MS_LOG(WARNING) << "reshape_default_param is nullptr!";
155 return -1;
156 }
157 auto reshape_value = std::dynamic_pointer_cast<tensor::Tensor>(reshape_default_param);
158 if (reshape_value == nullptr) {
159 MS_LOG(WARNING) << "reshape_value is nullptr!";
160 return -1;
161 }
162 if (reshape_value->ElementsNum() != kNumShapeSize4) {
163 MS_LOG(WARNING) << "reshape_value elements num is not 4, ElementsNum is: " << reshape_value->ElementsNum();
164 return -1;
165 }
166 if (reshape_value->data_type() != kNumberTypeInt32) {
167 MS_LOG(WARNING) << "reshape_value data type is not int32; other data types are not supported.";
168 return -1;
169 }
170 auto reshape_data = static_cast<int32_t *>(reshape_value->data_c());
171 if (reshape_data == nullptr) {
172 MS_LOG(WARNING) << "reshape_data is nullptr.";
173 return -1;
174 }
175 return static_cast<int32_t>(reshape_data[index]);
176 }
177
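// Extracts the head count from the Concat node that builds the Reshape target shape in stable diffusion graphs.
// The Concat input at index 3 must be a scalar parameter of type int32 or int64; returns -1 on failure.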
178 int64_t GetNumHeadForSD(const AnfNodePtr &q_trans_reshape) {
179 auto concat_cnode = q_trans_reshape->cast<CNodePtr>()->input(kNumIndex2)->cast<CNodePtr>();
180 if (concat_cnode == nullptr) {
181 MS_LOG(WARNING) << "concat_cnode is nullptr.";
182 return -1;
183 }
184 auto concat_const_input = concat_cnode->input(kNumIndex3);
185 if (!utils::isa<ParameterPtr>(concat_const_input)) {
186 MS_LOG(WARNING) << "concat_const_input is not ParameterPtr.";
187 return -1;
188 }
189 auto concat_param = concat_cnode->input(kNumIndex3)->cast<ParameterPtr>()->default_param();
190 if (concat_param == nullptr) {
191 MS_LOG(WARNING) << "concat_param is nullptr.";
192 return -1;
193 }
194 auto concat_value = std::dynamic_pointer_cast<tensor::Tensor>(concat_param);
195 if (concat_value == nullptr) {
196 MS_LOG(WARNING) << "concat_value is nullptr.";
197 return -1;
198 }
199 if (concat_value->ElementsNum() != 1) {
200 MS_LOG(WARNING) << "concat value elements num is not 1, ElementsNum is: " << concat_value->ElementsNum();
201 return -1;
202 }
203 if (concat_value->data_type() == kNumberTypeInt32) {
204 auto concat_data = static_cast<int32_t *>(concat_value->data_c());
205 if (concat_data == nullptr) {
206 MS_LOG(WARNING) << "concat_data is nullptr.";
207 return -1;
208 }
209 return static_cast<int64_t>(concat_data[0]);
210 } else if (concat_value->data_type() == kNumberTypeInt64) {
211 auto concat_data = static_cast<int64_t *>(concat_value->data_c());
212 if (concat_data == nullptr) {
213 MS_LOG(WARNING) << "concat_data is nullptr.";
214 return -1;
215 }
216 return static_cast<int64_t>(concat_data[0]);
217 } else {
218 MS_LOG(WARNING) << "head num data type is not int32 or int64; other data types are not supported.";
219 return -1;
220 }
221 }
222
223 // reshape(matmul, concat)->trans(reshape)->FA
224 bool CheckIpAdapterInput(const CNodePtr &input) {
225 auto trans = input;
226 MS_CHECK_TRUE_RET(trans != nullptr, false);
227 if (!CheckPrimitiveType(trans, prim::kPrimTranspose)) {
228 MS_LOG(INFO) << "node is not Transpose: " << trans->fullname_with_scope();
229 return false;
230 }
231
232 if (trans->inputs().size() != kNumShapeSize3) {
233 return false;
234 }
235 auto reshape = trans->input(kNumIndex1)->cast<CNodePtr>();
236 MS_CHECK_TRUE_RET(reshape != nullptr, false);
237 if (!CheckPrimitiveType(reshape, prim::kPrimReshape)) {
238 MS_LOG(INFO) << "node is not Reshape: " << reshape->fullname_with_scope();
239 return false;
240 }
241
242 if (reshape->inputs().size() != kNumShapeSize3) {
243 return false;
244 }
245 auto matmul = reshape->input(kNumIndex1)->cast<CNodePtr>();
246 MS_CHECK_TRUE_RET(matmul != nullptr, false);
247 if (!CheckPrimitiveType(matmul, prim::kPrimMatMulFusion)) {
248 MS_LOG(INFO) << "node is not MatMulFusion: " << matmul->fullname_with_scope();
249 return false;
250 }
251 auto concat = reshape->input(kNumIndex2)->cast<CNodePtr>();
252 MS_CHECK_TRUE_RET(concat != nullptr, false);
253 if (!CheckPrimitiveType(concat, prim::kPrimConcat)) {
254 MS_LOG(INFO) << "node is not Concat: " << concat->fullname_with_scope();
255 return false;
256 }
257 return true;
258 }
259
260 bool IpAdapterPattern(const CNodePtr q_input, const CNodePtr k_input) {
261 return CheckIpAdapterInput(q_input) && CheckIpAdapterInput(k_input);
262 }
263
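// Walks the query branch of a PD2 decoder block: Transpose(BNSD) <- Reshape <- Transpose(BSH) <- Reshape <- Conv2DFusion.
// Returns the Conv2DFusion node if the chain matches, otherwise nullptr.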
264 const CNodePtr PD2DecoderPattern(const CNodePtr &q_trans_BNSD) {
265 auto q_reshape_BSND = q_trans_BNSD->input(kNumIndex1)->cast<CNodePtr>();
266 MS_CHECK_TRUE_RET(q_reshape_BSND != nullptr && q_reshape_BSND->inputs().size() == kNumShapeSize3, nullptr);
267 if (!CheckPrimitiveType(q_reshape_BSND, prim::kPrimReshape)) {
268 MS_LOG(INFO) << "node is not Reshape: " << q_reshape_BSND->fullname_with_scope();
269 return nullptr;
270 }
271
272 auto q_trans_BSH = q_reshape_BSND->input(kNumIndex1)->cast<CNodePtr>();
273 MS_CHECK_TRUE_RET(q_trans_BSH != nullptr && q_trans_BSH->inputs().size() == kNumShapeSize3, nullptr);
274 if (!CheckPrimitiveType(q_trans_BSH, prim::kPrimTranspose)) {
275 MS_LOG(INFO) << "node is not Transpose: " << q_trans_BSH->fullname_with_scope();
276 return nullptr;
277 }
278
279 auto q_reshape_BSH = q_trans_BSH->input(kNumIndex1)->cast<CNodePtr>();
280 MS_CHECK_TRUE_RET(q_reshape_BSH != nullptr && q_reshape_BSH->inputs().size() == kNumShapeSize3, nullptr);
281 if (!CheckPrimitiveType(q_reshape_BSH, prim::kPrimReshape)) {
282 MS_LOG(INFO) << "node is not Reshape: " << q_reshape_BSH->fullname_with_scope();
283 return nullptr;
284 }
285
286 auto q_conv = q_reshape_BSH->input(kNumIndex1)->cast<CNodePtr>();
287 if (!CheckPrimitiveType(q_conv, prim::kPrimConv2DFusion)) {
288 MS_LOG(INFO) << "node is not Conv2DFusion: " << q_conv->fullname_with_scope();
289 return nullptr;
290 }
291 return q_conv;
292 }
293
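// Fetches the static shape of the given input of a cnode from its abstract; returns an empty vector on failure.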
294 std::vector<int64_t> GetTensorShape(CNodePtr cnode, size_t input_index) {
295 auto abstract = GetCNodeInputAbstract(cnode, input_index);
296 if (abstract == nullptr) {
297 MS_LOG(ERROR) << "GetCNodeInputAbstract failed in prompt flash attention fusion.";
298 return {};
299 }
300 std::vector<int64_t> shape = {};
301 if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
302 MS_LOG(ERROR) << "FetchShapeFromAbstract failed.";
303 return {};
304 }
305 return shape;
306 }
307
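// Derives num_head and d_value for the IP-Adapter pattern: the hidden size is taken from the second-input shape
// of the matmul feeding the Q reshape, num_head comes from GetNumHeadForSD, and d_value = hidden / num_head.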
308 bool GetParamForIpAdapterPattern(const CNodePtr &q_trans_BNSD, const CNodePtr &k_trans_BNDS, int64_t *num_head,
309 int64_t *d_value) {
310 if (num_head == nullptr) {
311 MS_LOG(ERROR) << "GetParamForIpAdapterPattern failed, num_head is nullptr!";
312 return false;
313 }
314 if (d_value == nullptr) {
315 MS_LOG(ERROR) << "GetParamForIpAdapterPattern failed, d_value is nullptr!";
316 return false;
317 }
318 auto q_reshape_BSND = q_trans_BNSD->input(kNumIndex1)->cast<CNodePtr>();
319 MS_CHECK_TRUE_RET(q_reshape_BSND != nullptr, false);
320 auto q_top_matmul = q_reshape_BSND->input(kNumIndex1)->cast<CNodePtr>();
321 MS_CHECK_TRUE_RET(q_top_matmul != nullptr, false);
322 auto q_matmul_input_2_shape = GetTensorShape(q_top_matmul, kNumIndex2);
323 auto k_reshape_BSND = k_trans_BNDS->input(kNumIndex1)->cast<CNodePtr>();
324 MS_CHECK_TRUE_RET(k_reshape_BSND != nullptr, false);
325 auto k_top_matmul = k_reshape_BSND->input(kNumIndex1)->cast<CNodePtr>();
326 MS_CHECK_TRUE_RET(k_top_matmul != nullptr, false);
327 auto k_matmul_input_2_shape = GetTensorShape(k_top_matmul, kNumIndex2);
328 MS_LOG(INFO) << "q_top_matmul name: " << q_top_matmul->fullname_with_scope()
329 << ", k_top_matmul name: " << k_top_matmul->fullname_with_scope()
330 << ", q matmul input2 shape: " << q_matmul_input_2_shape
331 << ", k matmul input2 shape: " << k_matmul_input_2_shape;
332 if (q_matmul_input_2_shape.size() != kNumShapeSize2 || k_matmul_input_2_shape.size() != kNumShapeSize2) {
333 MS_LOG(INFO) << "Matmul input 2 shape is not 2D, cannot fuse FA.";
334 return false;
335 }
336 auto num_h = q_matmul_input_2_shape[kNumIndex1];
337 *num_head = GetNumHeadForSD(q_reshape_BSND);
338 *d_value = num_h / *num_head;
339 return true;
340 }
341 } // namespace
342
343 std::string FlashAttentionFusion::soc_version_;
344
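// Registers every supported flash attention source pattern; each entry maps a pattern name to the sub-graph to match.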
345 std::unordered_map<std::string, VectorRef> FlashAttentionFusion::DefinePatterns() const {
346 MS_LOG(INFO) << "start define flash attention fusion patterns.";
347 std::unordered_map<std::string, VectorRef> patterns;
348 patterns[kNameFlashAttentionPatternForMsSD21] = DefineFlashAttentionPatternForMsSD21();
349 patterns[kNameFlashAttentionPatternForMsSDXL] = DefineFlashAttentionPatternForMsSDXL();
350 patterns[kNameFlashAttentionPatternForVideoComposer] = DefineFlashAttentionPatternForVideoComposer();
351 patterns[kNameFlashAttentionPatternForSDBSH] = DefineFlashAttentionPatternForSDBSH();
352 patterns[kNameFlashAttentionPatternForSDWithoutCast] = DefineFlashAttentionPatternForSDWithoutCast();
353 patterns[kNameFlashAttentionPatternForSDPreMul] = DefineFlashAttentionPatternForSDPreMul();
354 patterns[kNameFlashAttentionPatternForPanGu] = DefineFlashAttentionPatternForPanGu();
355 patterns[kNameFlashAttentionPatternForLLAMAPatternV1] = DefineFlashAttentionPatternForLLAMAPatternV1();
356 patterns[kNameFlashAttentionPatternForLLAMAPatternV2] = DefineFlashAttentionPatternForLLAMAPatternV2();
357 patterns[kNameFlashAttentionPatternForBaiChuan] = DefineFlashAttentionPatternForBaiChuan();
358 patterns[kNameFlashAttentionPatternForMsSDPseShift] = DefineFlashAttentionPatternForMsSDPseShift();
359 patterns[kNameFlashAttentionPatternForSDEinsum] = DefineFlashAttentionPatternForSDEinsum();
360 return patterns;
361 }
362
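// Creates a PadFusion node that zero-pads the last axis of `node` by `pad_size` (CONSTANT mode, value 0.0),
// typically used to round the head dim up to a size the fused attention supports. If the node name already
// carries the _fa_pad suffix, the existing node is returned instead of creating a second pad.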
363 CNodePtr FlashAttentionFusion::CreatePadCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int32_t pad_size,
364 const std::string &node_name) const {
365 MS_LOG(INFO) << "add pad node for prompt flash attention.";
366 if (node->fullname_with_scope().find(kNamePadNodeSuffix) != std::string::npos) {
367 MS_LOG(WARNING) << "node name contains " << kNamePadNodeSuffix << ", pad node is not created";
368 return node->cast<CNodePtr>();
369 }
370 auto pad_prim = std::make_shared<ops::PadFusion>();
371 if (pad_prim == nullptr) {
372 MS_LOG(ERROR) << "new pad prim failed, prim is nullptr.";
373 return nullptr;
374 }
375
376 pad_prim->AddAttr("padding_mode", api::MakeValue(PaddingMode::CONSTANT));
377 pad_prim->AddAttr("constant_value", api::MakeValue(0.0));
378 std::vector<std::vector<int32_t>> paddings = {{0, 0}, {0, 0}, {0, 0}, {0, pad_size}};
379
380 auto pad_prim_c = pad_prim->GetPrim();
381 if (pad_prim_c == nullptr) {
382 MS_LOG(WARNING) << "pad_prim_c is nullptr.";
383 return nullptr;
384 }
385 AnfNodePtr paddings_node = BuildIntVec2DParameterNode(
386 func_graph, paddings, node->fullname_with_scope() + std::to_string(kNameIndex) + "_paddings");
387 if (paddings_node == nullptr) {
388 MS_LOG(WARNING) << "paddings_node is nullptr.";
389 return nullptr;
390 }
391 auto inputs = {node, paddings_node};
392 auto pad_cnode = func_graph->NewCNode(pad_prim_c, inputs);
393 if (pad_cnode == nullptr) {
394 MS_LOG(ERROR) << "new pad cnode failed, cnode is nullptr.";
395 return nullptr;
396 }
397 pad_cnode->set_fullname_with_scope(node->fullname_with_scope() + std::to_string(kNameIndex++) + kNamePadNodeSuffix);
398 if (node->abstract() != nullptr) {
399 pad_cnode->set_abstract(node->abstract()->Clone());
400 }
401 MS_LOG(INFO) << "create pad node end.";
402 return pad_cnode;
403 }
404
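// Creates a Slice node that keeps the full extent of the first three axes and the first `slice_size` elements of
// the last axis (begin = {0, 0, 0, 0}, size = {-1, -1, -1, slice_size}); intended as the counterpart of CreatePadCNode.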
405 CNodePtr FlashAttentionFusion::CreateSliceCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
406 int32_t slice_size) const {
407 MS_LOG(INFO) << "add slice node for prompt flash attention.";
408 auto slice_prim = std::make_shared<ops::Slice>();
409 if (slice_prim == nullptr) {
410 MS_LOG(ERROR) << "new slice prim failed, prim is nullptr.";
411 return nullptr;
412 }
413
414 std::vector<int32_t> begin = {0, 0, 0, 0};
415 std::vector<int32_t> size = {-1, -1, -1, slice_size};
416
417 auto slice_prim_c = slice_prim->GetPrim();
418 if (slice_prim_c == nullptr) {
419 MS_LOG(ERROR) << "slice prim c is nullptr.";
420 return nullptr;
421 }
422
423 AnfNodePtr begin_node = BuildIntVecParameterNode(func_graph, begin, node->fullname_with_scope() + "_begin");
424 if (begin_node == nullptr) {
425 MS_LOG(WARNING) << "BuildIntVecParameterNode failed.";
426 return nullptr;
427 }
428 AnfNodePtr size_node = BuildIntVecParameterNode(func_graph, size, node->fullname_with_scope() + "_size");
429 if (size_node == nullptr) {
430 MS_LOG(WARNING) << "BuildIntVecParameterNode failed.";
431 return nullptr;
432 }
433
434 auto inputs = {node, begin_node, size_node};
435 auto slice_cnode = func_graph->NewCNode(slice_prim_c, inputs);
436 if (slice_cnode == nullptr) {
437 MS_LOG(WARNING) << "create slice_cnode failed.";
438 return nullptr;
439 }
440
441 slice_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_fa_slice");
442 if (node->abstract() != nullptr) {
443 slice_cnode->set_abstract(node->abstract()->Clone());
444 }
445 MS_LOG(INFO) << "create slice node end.";
446 return slice_cnode;
447 }
448
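// Pattern for MindSpore SD 2.1:
//   Reshape(Q) ---\
//                  BatchMatMul -> Mul(scale) -> Softmax -> BatchMatMul -> output
//   Transpose(K) -/                                            ^
//                                                 V -----------/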
449 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForMsSD21() const {
450 // reshape Q
451 auto q_input = std::make_shared<Var>();
452 auto reshape_q_input_2 = std::make_shared<Var>();
453 MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
454 auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
455 MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
456 auto reshape_q = VectorRef({is_reshape_q, q_input, reshape_q_input_2});
457
458 // transpose
459 auto k_input = std::make_shared<Var>();
460 auto is_transpose_param = std::make_shared<Var>();
461 MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
462 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
463 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
464 auto transpose = VectorRef({is_transpose, k_input, is_transpose_param});
465
466 // matmul 1
467 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
468 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
469 auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
470 // q mul
471 auto is_mul_param = std::make_shared<Var>();
472 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
473 auto is_mul_qk = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
474 MS_CHECK_TRUE_RET(is_mul_qk != nullptr, {});
475 auto mul_qk = VectorRef({is_mul_qk, matmul_1, is_mul_param});
476 // softmax
477 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
478 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
479 auto softmax = VectorRef({is_softmax, mul_qk});
480
481 // matmul 2
482 auto v = std::make_shared<Var>(); // input V
483 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
484 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
485 auto matmul_2 = VectorRef({is_matmul_2, softmax, v});
486 return matmul_2;
487 }
488
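// Pattern for MindSpore SD with pse shift: MatMulFusion(Reshape(Q), Transpose(Reshape(K))) -> Mul(scale)
// -> Add(pse shift) -> Softmax -> Cast -> MatMulFusion(Reshape(V)) -> output Reshape.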
489 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForMsSDPseShift() const {
490 // Q
491 auto q_trans_input = std::make_shared<Var>();
492 MS_CHECK_TRUE_RET(q_trans_input != nullptr, {});
493 auto reshape_q_input_2 = std::make_shared<Var>();
494 MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
495 auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
496 MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
497 auto reshape_q = VectorRef({is_reshape_q, q_trans_input, reshape_q_input_2});
498
499 // K
500 auto k_trans_input = std::make_shared<Var>();
501 MS_CHECK_TRUE_RET(k_trans_input != nullptr, {});
502 auto reshape_k_input_2 = std::make_shared<Var>();
503 MS_CHECK_TRUE_RET(reshape_k_input_2 != nullptr, {});
504 auto is_reshape_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
505 MS_CHECK_TRUE_RET(is_reshape_k != nullptr, {});
506 auto reshape_k = VectorRef({is_reshape_k, k_trans_input, reshape_k_input_2});
507 MS_CHECK_TRUE_RET(reshape_k != nullptr, {});
508 auto is_transpose_param = std::make_shared<Var>();
509 MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
510 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
511 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
512 auto transpose = VectorRef({is_transpose, reshape_k, is_transpose_param});
513
514 // matmul
515 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
516 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
517 auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
518
519 // mul
520 auto is_mul_param = std::make_shared<Var>();
521 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
522 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
523 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
524 auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
525
526 // add
527 auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
528 MS_CHECK_TRUE_RET(is_add != nullptr, {});
529 auto add_input_2 = std::make_shared<Var>();
530 MS_CHECK_TRUE_RET(add_input_2 != nullptr, {});
531 auto add = VectorRef({is_add, mul, add_input_2});
532
533 // softmax
534 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
535 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
536 auto softmax = VectorRef({is_softmax, add});
537
538 // cast
539 auto is_cast_param = std::make_shared<Var>();
540 MS_CHECK_TRUE_RET(is_cast_param != nullptr, {});
541 auto is_cast = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
542 MS_CHECK_TRUE_RET(is_cast != nullptr, {});
543 auto cast = VectorRef({is_cast, softmax, is_cast_param});
544
545 // V
546 auto v_trans_input = std::make_shared<Var>();
547 MS_CHECK_TRUE_RET(v_trans_input != nullptr, {});
548 auto reshape_v_input_2 = std::make_shared<Var>();
549 MS_CHECK_TRUE_RET(reshape_v_input_2 != nullptr, {});
550 auto is_reshape_v = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
551 MS_CHECK_TRUE_RET(is_reshape_v != nullptr, {});
552 auto reshape_v = VectorRef({is_reshape_v, v_trans_input, reshape_v_input_2});
553
554 // matmul 2
555 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
556 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
557 auto matmul_2 = VectorRef({is_matmul_2, cast, reshape_v});
558
559 // reshape
560 auto reshape_output_input_2 = std::make_shared<Var>();
561 MS_CHECK_TRUE_RET(reshape_output_input_2 != nullptr, {});
562 auto is_reshape_output = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
563 MS_CHECK_TRUE_RET(is_reshape_output != nullptr, {});
564 auto reshape_output = VectorRef({is_reshape_output, matmul_2, reshape_output_input_2});
565 return reshape_output;
566 }
567
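// Pattern for MindSpore SDXL: BatchMatMul(Q, K) -> Div(scale) -> Softmax -> BatchMatMul(V).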
568 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForMsSDXL() const {
569 // matmul 1
570 auto q = std::make_shared<Var>(); // input Q
571 MS_CHECK_TRUE_RET(q != nullptr, {});
572 auto k = std::make_shared<Var>(); // input K
573 MS_CHECK_TRUE_RET(k != nullptr, {});
574 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
575 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
576 auto matmul_1 = VectorRef({is_matmul_1, q, k});
577 // q div
578 auto is_div_q_param = std::make_shared<Var>();
579 MS_CHECK_TRUE_RET(is_div_q_param != nullptr, {});
580 auto is_div_q = std::make_shared<CondVar>(IsDivNode);
581 MS_CHECK_TRUE_RET(is_div_q != nullptr, {});
582 auto div_q = VectorRef({is_div_q, matmul_1, is_div_q_param});
583 // softmax
584 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
585 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
586 auto softmax = VectorRef({is_softmax, div_q});
587
588 // matmul 2
589 auto v = std::make_shared<Var>(); // input V
590 MS_CHECK_TRUE_RET(v != nullptr, {});
591 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
592 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
593 auto matmul_2 = VectorRef({is_matmul_2, softmax, v});
594 return matmul_2;
595 }
596
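// Pattern for VideoComposer: Q, K and V each pass through Transpose + Reshape (K gets a second Transpose),
// then BatchMatMul -> Mul(scale) -> Cast -> Softmax -> Cast -> BatchMatMul -> output Reshape.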
597 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForVideoComposer() const {
598 // q trans
599 auto input_q = std::make_shared<Var>();
600 MS_CHECK_TRUE_RET(input_q != nullptr, {});
601 auto input_q_perm = std::make_shared<Var>();
602 MS_CHECK_TRUE_RET(input_q_perm != nullptr, {});
603 auto is_q_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
604 MS_CHECK_TRUE_RET(is_q_transpose != nullptr, {});
605 auto q_transpose = VectorRef({is_q_transpose, input_q, input_q_perm});
606 // q reshape
607 auto reshape_q_input = std::make_shared<Var>();
608 MS_CHECK_TRUE_RET(reshape_q_input != nullptr, {});
609 auto is_q_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
610 MS_CHECK_TRUE_RET(is_q_reshape != nullptr, {});
611 auto reshape_q = VectorRef({is_q_reshape, q_transpose, reshape_q_input});
612 // k trans
613 auto input_k = std::make_shared<Var>();
614 MS_CHECK_TRUE_RET(input_k != nullptr, {});
615 auto input_k_perm = std::make_shared<Var>();
616 MS_CHECK_TRUE_RET(input_k_perm != nullptr, {});
617 auto is_k_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
618 MS_CHECK_TRUE_RET(is_k_transpose != nullptr, {});
619 auto k_transpose = VectorRef({is_k_transpose, input_k, input_k_perm});
620 // k reshape
621 auto reshape_k_input = std::make_shared<Var>();
622 MS_CHECK_TRUE_RET(reshape_k_input != nullptr, {});
623 auto is_k_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
624 MS_CHECK_TRUE_RET(is_k_reshape != nullptr, {});
625 auto reshape_k = VectorRef({is_k_reshape, k_transpose, reshape_k_input});
626 // k trans 2
627 auto input_k_trans_2_perm = std::make_shared<Var>();
628 MS_CHECK_TRUE_RET(input_k_trans_2_perm != nullptr, {});
629 auto is_k_transpose_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
630 MS_CHECK_TRUE_RET(is_k_transpose_2 != nullptr, {});
631 auto k_transpose_2 = VectorRef({is_k_transpose_2, reshape_k, input_k_trans_2_perm});
632
633 // v trans
634 auto input_v = std::make_shared<Var>();
635 MS_CHECK_TRUE_RET(input_v != nullptr, {});
636 auto input_v_perm = std::make_shared<Var>();
637 MS_CHECK_TRUE_RET(input_v_perm != nullptr, {});
638 auto is_v_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
639 MS_CHECK_TRUE_RET(is_v_transpose != nullptr, {});
640 auto v_transpose = VectorRef({is_v_transpose, input_v, input_v_perm});
641 // v reshape
642 auto reshape_v_input = std::make_shared<Var>();
643 MS_CHECK_TRUE_RET(reshape_v_input != nullptr, {});
644 auto is_v_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
645 MS_CHECK_TRUE_RET(is_v_reshape != nullptr, {});
646 auto reshape_v = VectorRef({is_v_reshape, v_transpose, reshape_v_input});
647
648 // matmul 1
649 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
650 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
651 auto matmul_1 = VectorRef({is_matmul_1, reshape_q, k_transpose_2});
652 // mul
653 auto is_mul_param = std::make_shared<Var>();
654 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
655 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
656 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
657 auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
658
659 // cast
660 auto is_cast_1_param = std::make_shared<Var>();
661 MS_CHECK_TRUE_RET(is_cast_1_param != nullptr, {});
662 auto is_cast_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
663 MS_CHECK_TRUE_RET(is_cast_1 != nullptr, {});
664 auto cast_1 = VectorRef({is_cast_1, mul, is_cast_1_param});
665
666 // softmax
667 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
668 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
669 auto softmax = VectorRef({is_softmax, cast_1});
670 // cast
671 auto is_cast_param = std::make_shared<Var>();
672 MS_CHECK_TRUE_RET(is_cast_param != nullptr, {});
673 auto is_cast = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
674 MS_CHECK_TRUE_RET(is_cast != nullptr, {});
675 auto cast = VectorRef({is_cast, softmax, is_cast_param});
676 // matmul
677 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
678 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
679 auto matmul_2 = VectorRef({is_matmul_2, cast, reshape_v});
680
681 // output reshape to four dims
682 auto reshape_o_2 = std::make_shared<Var>();
683 MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
684 auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
685 MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
686 auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
687 return reshape_o;
688 }
689
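// Pattern for SD in BNSD layout: MatMulFusion(Reshape(Q), Transpose(Reshape(K))) -> Mul(scale) -> Softmax
// -> Cast -> MatMulFusion(Reshape(V)) -> output Reshape.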
690 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDBNSD() const {
691 // Q reshape
692 auto reshape_q_input_1 = std::make_shared<Var>(); // input Q
693 MS_CHECK_TRUE_RET(reshape_q_input_1 != nullptr, {});
694 auto reshape_q_input_2 = std::make_shared<Var>();
695 MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
696 auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
697 MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
698 auto reshape_q = VectorRef({is_reshape_q, reshape_q_input_1, reshape_q_input_2});
699 // K reshape
700 auto reshape_k_input_1 = std::make_shared<Var>(); // input K
701 MS_CHECK_TRUE_RET(reshape_k_input_1 != nullptr, {});
702 auto reshape_k_input_2 = std::make_shared<Var>();
703 MS_CHECK_TRUE_RET(reshape_k_input_2 != nullptr, {});
704 auto is_reshape_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
705 MS_CHECK_TRUE_RET(is_reshape_k != nullptr, {});
706 auto reshape_k = VectorRef({is_reshape_k, reshape_k_input_1, reshape_k_input_2});
707 // transpose
708 auto is_transpose_param = std::make_shared<CondVar>(IsParamNode);
709 MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
710 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
711 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
712 auto transpose = VectorRef({is_transpose, reshape_k, is_transpose_param});
713 // matmul 1
714 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
715 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
716 auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
717 // mul
718 auto is_mul_param = std::make_shared<CondVar>(IsParamNode);
719 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
720 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
721 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
722 auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
723 // softmax
724 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
725 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
726 auto softmax = VectorRef({is_softmax, mul});
727 // cast
728 auto is_cast_param = std::make_shared<CondVar>(IsParamNode);
729 MS_CHECK_TRUE_RET(is_cast_param != nullptr, {});
730 auto is_cast = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
731 MS_CHECK_TRUE_RET(is_cast != nullptr, {});
732 auto cast = VectorRef({is_cast, softmax, is_cast_param});
733 // V reshape
734 auto reshape_v_input_1 = std::make_shared<Var>(); // input V
735 MS_CHECK_TRUE_RET(reshape_v_input_1 != nullptr, {});
736 auto reshape_v_input_2 = std::make_shared<Var>();
737 MS_CHECK_TRUE_RET(reshape_v_input_2 != nullptr, {});
738 auto is_reshape_v = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
739 MS_CHECK_TRUE_RET(is_reshape_v != nullptr, {});
740 auto reshape_v = VectorRef({is_reshape_v, reshape_v_input_1, reshape_v_input_2});
741 // matmul
742 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
743 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
744 auto matmul_2 = VectorRef({is_matmul_2, cast, reshape_v});
745 // output reshape to four dims
746 auto reshape_o_2 = std::make_shared<Var>();
747 MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
748 auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
749 MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
750 auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
751 return reshape_o;
752 }
753
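// Pattern for SD in BSH layout: like the BNSD pattern, but each of Q/K/V is first lifted from three to four dims
// via Reshape + Transpose + Reshape, and the output goes through Reshape -> Transpose -> Reshape back to three dims.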
754 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDBSH() const {
755 // Q: three dim input reshape to four dims
756 auto input_q_reshape_param_1 = std::make_shared<Var>();
757 MS_CHECK_TRUE_RET(input_q_reshape_param_1 != nullptr, {});
758 auto input_q_reshape_param_2 = std::make_shared<Var>();
759 MS_CHECK_TRUE_RET(input_q_reshape_param_2 != nullptr, {});
760 auto is_input_q_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
761 MS_CHECK_TRUE_RET(is_input_q_reshape != nullptr, {});
762 auto input_q_reshape = VectorRef({is_input_q_reshape, input_q_reshape_param_1, input_q_reshape_param_2});
763 // transpose
764 auto is_input_q_transpose_param = std::make_shared<Var>();
765 MS_CHECK_TRUE_RET(is_input_q_transpose_param != nullptr, {});
766 auto is_input_q_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
767 MS_CHECK_TRUE_RET(is_input_q_transpose != nullptr, {});
768 auto input_q_transpose = VectorRef({is_input_q_transpose, input_q_reshape, is_input_q_transpose_param});
769 // Q reshape
770 auto reshape_q_input_2 = std::make_shared<Var>();
771 MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
772 auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
773 MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
774 auto reshape_q = VectorRef({is_reshape_q, input_q_transpose, reshape_q_input_2});
775
776 // K: three dim input reshape to four dims
777 auto input_k_reshape_param_1 = std::make_shared<Var>();
778 MS_CHECK_TRUE_RET(input_k_reshape_param_1 != nullptr, {});
779 auto input_k_reshape_param_2 = std::make_shared<Var>();
780 MS_CHECK_TRUE_RET(input_k_reshape_param_2 != nullptr, {});
781 auto is_input_k_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
782 MS_CHECK_TRUE_RET(is_input_k_reshape != nullptr, {});
783 auto input_k_reshape = VectorRef({is_input_k_reshape, input_k_reshape_param_1, input_k_reshape_param_2});
784 // transpose
785 auto is_input_k_transpose_param = std::make_shared<Var>();
786 MS_CHECK_TRUE_RET(is_input_k_transpose_param != nullptr, {});
787 auto is_input_k_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
788 MS_CHECK_TRUE_RET(is_input_k_transpose != nullptr, {});
789 auto input_k_transpose = VectorRef({is_input_k_transpose, input_k_reshape, is_input_k_transpose_param});
790 // K reshape
791 auto reshape_k_input_2 = std::make_shared<Var>();
792 MS_CHECK_TRUE_RET(reshape_k_input_2 != nullptr, {});
793 auto is_reshape_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
794 MS_CHECK_TRUE_RET(is_reshape_k != nullptr, {});
795 auto reshape_k = VectorRef({is_reshape_k, input_k_transpose, reshape_k_input_2});
796 // transpose
797 auto is_transpose_param = std::make_shared<CondVar>(IsParamNode);
798 MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
799 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
800 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
801 auto transpose = VectorRef({is_transpose, reshape_k, is_transpose_param});
802 // matmul 1
803 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
804 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
805 auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
806 // mul
807 auto is_mul_param = std::make_shared<CondVar>(IsParamNode);
808 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
809 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
810 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
811 auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
812 // softmax
813 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
814 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
815 auto softmax = VectorRef({is_softmax, mul});
816 // cast
817 auto is_cast_param = std::make_shared<CondVar>(IsParamNode);
818 MS_CHECK_TRUE_RET(is_cast_param != nullptr, {});
819 auto is_cast = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
820 MS_CHECK_TRUE_RET(is_cast != nullptr, {});
821 auto cast = VectorRef({is_cast, softmax, is_cast_param});
822
823 // V: three dim input reshape to four dims
824 auto input_v_reshape_param_1 = std::make_shared<Var>();
825 MS_CHECK_TRUE_RET(input_v_reshape_param_1 != nullptr, {});
826 auto input_v_reshape_param_2 = std::make_shared<Var>();
827 MS_CHECK_TRUE_RET(input_v_reshape_param_2 != nullptr, {});
828 auto is_input_v_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
829 MS_CHECK_TRUE_RET(is_input_v_reshape != nullptr, {});
830 auto input_v_reshape = VectorRef({is_input_v_reshape, input_v_reshape_param_1, input_v_reshape_param_2});
831 // transpose
832 auto is_input_v_transpose_param = std::make_shared<CondVar>(IsParamNode);
833 MS_CHECK_TRUE_RET(is_input_v_transpose_param != nullptr, {});
834 auto is_input_v_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
835 MS_CHECK_TRUE_RET(is_input_v_transpose != nullptr, {});
836 auto input_v_transpose = VectorRef({is_input_v_transpose, input_v_reshape, is_input_v_transpose_param});
837 // V reshape
838 auto reshape_v_input_2 = std::make_shared<Var>();
839 MS_CHECK_TRUE_RET(reshape_v_input_2 != nullptr, {});
840 auto is_reshape_v = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
841 MS_CHECK_TRUE_RET(is_reshape_v != nullptr, {});
842 auto reshape_v = VectorRef({is_reshape_v, input_v_transpose, reshape_v_input_2});
843 // matmul
844 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
845 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
846 auto matmul_2 = VectorRef({is_matmul_2, cast, reshape_v});
847 // output reshape to four dims
848 auto reshape_o_2 = std::make_shared<Var>();
849 MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
850 auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
851 MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
852 auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
853 // output transpose
854 auto is_transpose_o_param = std::make_shared<CondVar>(IsParamNode);
855 MS_CHECK_TRUE_RET(is_transpose_o_param != nullptr, {});
856 auto is_transpose_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
857 MS_CHECK_TRUE_RET(is_transpose_o != nullptr, {});
858 auto transpose_o = VectorRef({is_transpose_o, reshape_o, is_transpose_o_param});
859 // output reshape to three dims
860 auto reshape_o2_2 = std::make_shared<Var>();
861 MS_CHECK_TRUE_RET(reshape_o2_2 != nullptr, {});
862 auto is_reshape_o2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
863 MS_CHECK_TRUE_RET(is_reshape_o2 != nullptr, {});
864 auto reshape_o2 = VectorRef({is_reshape_o2, transpose_o, reshape_o2_2});
865 return reshape_o2;
866 }
867
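// Pattern for SD with pre-scaling: Q and K are each scaled by MulFusion before MatMulFusion, followed by
// Softmax -> MatMulFusion(V) -> Transpose -> Reshape; there is no separate scale Mul after the first matmul.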
868 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDPreMul() const {
869 // Q
870 auto q_input = std::make_shared<Var>(); // input Q
871 MS_CHECK_TRUE_RET(q_input != nullptr, {});
872 // mul
873 auto mul_q_val = std::make_shared<Var>();
874 MS_CHECK_TRUE_RET(mul_q_val != nullptr, {});
875 auto is_mul_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
876 MS_CHECK_TRUE_RET(is_mul_q != nullptr, {});
877 auto mul_q = VectorRef({is_mul_q, q_input, mul_q_val});
878
879 // K
880 auto k_input = std::make_shared<Var>(); // input K
881 MS_CHECK_TRUE_RET(k_input != nullptr, {});
882 // mul
883 auto mul_k_val = std::make_shared<Var>();
884 MS_CHECK_TRUE_RET(mul_k_val != nullptr, {});
885 auto is_mul_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
886 MS_CHECK_TRUE_RET(is_mul_k != nullptr, {});
887 auto mul_k = VectorRef({is_mul_k, k_input, mul_k_val});
888
889 // matmul 1
890 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
891 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
892 auto matmul_1 = VectorRef({is_matmul_1, mul_q, mul_k});
893
894 // softmax
895 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
896 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
897 auto softmax = VectorRef({is_softmax, matmul_1});
898
899 // matmul
900 auto mul_v = std::make_shared<Var>();
901 MS_CHECK_TRUE_RET(mul_v != nullptr, {});
902 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
903 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
904 auto matmul_2 = VectorRef({is_matmul_2, softmax, mul_v});
905
906 auto output_trans_input_2 = std::make_shared<Var>();
907 MS_CHECK_TRUE_RET(output_trans_input_2 != nullptr, {});
908 auto is_trans_output = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
909 MS_CHECK_TRUE_RET(is_trans_output != nullptr, {});
910 auto transpose_output = VectorRef({is_trans_output, matmul_2, output_trans_input_2});
911 MS_CHECK_TRUE_RET(transpose_output != nullptr, {});
912
913 auto output_reshape_input_2 = std::make_shared<Var>();
914 MS_CHECK_TRUE_RET(output_reshape_input_2 != nullptr, {});
915 auto is_reshape_output = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
916 MS_CHECK_TRUE_RET(is_reshape_output != nullptr, {});
917 auto output_reshape = VectorRef({is_reshape_output, transpose_output, output_reshape_input_2});
918 MS_CHECK_TRUE_RET(output_reshape != nullptr, {});
919 return output_reshape;
920 }
921
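// Pattern for SD without cast: same as the BSH pattern, but with no Cast between Softmax and the second MatMulFusion.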
922 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDWithoutCast() const {
923 // Q: three dim input reshape to four dims
924 auto input_q_reshape_param_1 = std::make_shared<Var>();
925 MS_CHECK_TRUE_RET(input_q_reshape_param_1 != nullptr, {});
926 auto input_q_reshape_param_2 = std::make_shared<Var>();
927 MS_CHECK_TRUE_RET(input_q_reshape_param_2 != nullptr, {});
928 auto is_input_q_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
929 MS_CHECK_TRUE_RET(is_input_q_reshape != nullptr, {});
930 auto input_q_reshape = VectorRef({is_input_q_reshape, input_q_reshape_param_1, input_q_reshape_param_2});
931 // transpose
932 auto is_input_q_transpose_param = std::make_shared<Var>();
933 MS_CHECK_TRUE_RET(is_input_q_transpose_param != nullptr, {});
934 auto is_input_q_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
935 MS_CHECK_TRUE_RET(is_input_q_transpose != nullptr, {});
936 auto input_q_transpose = VectorRef({is_input_q_transpose, input_q_reshape, is_input_q_transpose_param});
937 // Q reshape
938 auto reshape_q_input_2 = std::make_shared<Var>();
939 MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
940 auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
941 MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
942 auto reshape_q = VectorRef({is_reshape_q, input_q_transpose, reshape_q_input_2});
943
944 // K: three dim input reshape to four dims
945 auto input_k_reshape_param_1 = std::make_shared<Var>();
946 MS_CHECK_TRUE_RET(input_k_reshape_param_1 != nullptr, {});
947 auto input_k_reshape_param_2 = std::make_shared<Var>();
948 MS_CHECK_TRUE_RET(input_k_reshape_param_2 != nullptr, {});
949 auto is_input_k_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
950 MS_CHECK_TRUE_RET(is_input_k_reshape != nullptr, {});
951 auto input_k_reshape = VectorRef({is_input_k_reshape, input_k_reshape_param_1, input_k_reshape_param_2});
952 // transpose
953 auto is_input_k_transpose_param = std::make_shared<Var>();
954 MS_CHECK_TRUE_RET(is_input_k_transpose_param != nullptr, {});
955 auto is_input_k_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
956 MS_CHECK_TRUE_RET(is_input_k_transpose != nullptr, {});
957 auto input_k_transpose = VectorRef({is_input_k_transpose, input_k_reshape, is_input_k_transpose_param});
958 // K reshape
959 auto reshape_k_input_2 = std::make_shared<Var>();
960 MS_CHECK_TRUE_RET(reshape_k_input_2 != nullptr, {});
961 auto is_reshape_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
962 MS_CHECK_TRUE_RET(is_reshape_k != nullptr, {});
963 auto reshape_k = VectorRef({is_reshape_k, input_k_transpose, reshape_k_input_2});
964 // transpose
965 auto is_transpose_param = std::make_shared<CondVar>(IsParamNode);
966 MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
967 auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
968 MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
969 auto transpose = VectorRef({is_transpose, reshape_k, is_transpose_param});
970 // matmul 1
971 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
972 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
973 auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
974 // mul
975 auto is_mul_param = std::make_shared<CondVar>(IsParamNode);
976 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
977 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
978 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
979 auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
980 // softmax
981 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
982 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
983 auto softmax = VectorRef({is_softmax, mul});
984
985 // V: three dim input reshape to four dims
986 auto input_v_reshape_param_1 = std::make_shared<Var>();
987 MS_CHECK_TRUE_RET(input_v_reshape_param_1 != nullptr, {});
988 auto input_v_reshape_param_2 = std::make_shared<Var>();
989 MS_CHECK_TRUE_RET(input_v_reshape_param_2 != nullptr, {});
990 auto is_input_v_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
991 MS_CHECK_TRUE_RET(is_input_v_reshape != nullptr, {});
992 auto input_v_reshape = VectorRef({is_input_v_reshape, input_v_reshape_param_1, input_v_reshape_param_2});
993 // transpose
994 auto is_input_v_transpose_param = std::make_shared<CondVar>(IsParamNode);
995 MS_CHECK_TRUE_RET(is_input_v_transpose_param != nullptr, {});
996 auto is_input_v_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
997 MS_CHECK_TRUE_RET(is_input_v_transpose != nullptr, {});
998 auto input_v_transpose = VectorRef({is_input_v_transpose, input_v_reshape, is_input_v_transpose_param});
999 // V reshape
1000 auto reshape_v_input_2 = std::make_shared<Var>();
1001 MS_CHECK_TRUE_RET(reshape_v_input_2 != nullptr, {});
1002 auto is_reshape_v = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1003 MS_CHECK_TRUE_RET(is_reshape_v != nullptr, {});
1004 auto reshape_v = VectorRef({is_reshape_v, input_v_transpose, reshape_v_input_2});
1005 // matmul
1006 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
1007 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1008 auto matmul_2 = VectorRef({is_matmul_2, softmax, reshape_v});
1009 // output reshape to four dims
1010 auto reshape_o_2 = std::make_shared<Var>();
1011 MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
1012 auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1013 MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
1014 auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
1015 // output transpose
1016 auto is_transpose_o_param = std::make_shared<CondVar>(IsParamNode);
1017 MS_CHECK_TRUE_RET(is_transpose_o_param != nullptr, {});
1018 auto is_transpose_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
1019 MS_CHECK_TRUE_RET(is_transpose_o != nullptr, {});
1020 auto transpose_o = VectorRef({is_transpose_o, reshape_o, is_transpose_o_param});
1021 // output reshape to three dims
1022 auto reshape_o2_2 = std::make_shared<Var>();
1023 MS_CHECK_TRUE_RET(reshape_o2_2 != nullptr, {});
1024 auto is_reshape_o2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1025 MS_CHECK_TRUE_RET(is_reshape_o2 != nullptr, {});
1026 auto reshape_o2 = VectorRef({is_reshape_o2, transpose_o, reshape_o2_2});
1027 return reshape_o2;
1028 }
1029
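// Pattern for PanGu: RealDiv(Q, scale) -> BatchMatMul(K) -> Cast, added to Mul(attention mask), then
// Softmax -> Cast -> BatchMatMul(V).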
1030 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForPanGu() const {
1031 // q div
1032 auto q = std::make_shared<Var>(); // input Q
1033 MS_CHECK_TRUE_RET(q != nullptr, {});
1034 auto is_div_q_param = std::make_shared<Var>();
1035 MS_CHECK_TRUE_RET(is_div_q_param != nullptr, {});
1036 auto is_div_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRealDiv>);
1037 MS_CHECK_TRUE_RET(is_div_q != nullptr, {});
1038 auto div_q = VectorRef({is_div_q, q, is_div_q_param});
1039 // matmul 1
1040 auto k = std::make_shared<Var>(); // input K
1041 MS_CHECK_TRUE_RET(k != nullptr, {});
1042 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1043 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1044 auto matmul_1 = VectorRef({is_matmul_1, div_q, k});
1045 // cast 1
1046 auto is_cast_1_param = std::make_shared<Var>();
1047 MS_CHECK_TRUE_RET(is_cast_1_param != nullptr, {});
1048 auto is_cast_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
1049 MS_CHECK_TRUE_RET(is_cast_1 != nullptr, {});
1050 auto cast_1 = VectorRef({is_cast_1, matmul_1, is_cast_1_param});
1051 // ===== attention mask =====
1052 // sub
1053 auto atten_mask = std::make_shared<Var>();
1054 MS_CHECK_TRUE_RET(atten_mask != nullptr, {});
1055 // mul
1056 auto is_mask_mul_param = std::make_shared<Var>();
1057 MS_CHECK_TRUE_RET(is_mask_mul_param != nullptr, {});
1058 auto is_mask_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1059 MS_CHECK_TRUE_RET(is_mask_mul != nullptr, {});
1060 auto mask_mul = VectorRef({is_mask_mul, atten_mask, is_mask_mul_param});
1061 // ===== end attention mask =====
1062 // add
1063 auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1064 MS_CHECK_TRUE_RET(is_add != nullptr, {});
1065 auto add = VectorRef({is_add, mask_mul, cast_1});
1066 // softmax
1067 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1068 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1069 auto softmax = VectorRef({is_softmax, add});
1070 // cast 2
1071 auto is_cast_2_param = std::make_shared<Var>();
1072 MS_CHECK_TRUE_RET(is_cast_2_param != nullptr, {});
1073 auto is_cast_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
1074 MS_CHECK_TRUE_RET(is_cast_2 != nullptr, {});
1075 auto cast_2 = VectorRef({is_cast_2, softmax, is_cast_2_param});
1076
1077 // matmul 2
1078 auto v = std::make_shared<Var>(); // input V
1079 MS_CHECK_TRUE_RET(v != nullptr, {});
1080 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1081 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1082 auto matmul_2 = VectorRef({is_matmul_2, cast_2, v});
1083 return matmul_2;
1084 }
1085
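// Pattern for LLAMA (variant 1): BatchMatMul(Q, K) -> Mul(scale), added to Mul(attention mask), then
// Cast -> Softmax -> Cast -> BatchMatMul(V).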
1086 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForLLAMAPatternV1() const {
1087 // matmul 1
1088 auto matmul_1_q_input = std::make_shared<Var>(); // input Q
1089 MS_CHECK_TRUE_RET(matmul_1_q_input != nullptr, {});
1090 auto matmul_1_k_input = std::make_shared<Var>(); // input K
1091 MS_CHECK_TRUE_RET(matmul_1_k_input != nullptr, {});
1092 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1093 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1094 auto matmul_1 = VectorRef({is_matmul_1, matmul_1_q_input, matmul_1_k_input});
1095 // mul
1096 auto is_mul_param = std::make_shared<Var>();
1097 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
1098 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1099 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
1100 auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
1101 // ===== attention mask =====
1102 // sub
1103 auto sub_mask_input_1 = std::make_shared<Var>(); // input attention mask
1104 MS_CHECK_TRUE_RET(sub_mask_input_1 != nullptr, {});
1105 // mul
1106 auto is_mask_mul_param = std::make_shared<Var>();
1107 MS_CHECK_TRUE_RET(is_mask_mul_param != nullptr, {});
1108 auto is_mask_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1109 MS_CHECK_TRUE_RET(is_mask_mul != nullptr, {});
1110 auto mask_mul = VectorRef({is_mask_mul, sub_mask_input_1, is_mask_mul_param});
1111 // ===== end attention mask =====
1112 // add
1113 auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1114 MS_CHECK_TRUE_RET(is_add != nullptr, {});
1115 auto add = VectorRef({is_add, mask_mul, mul});
1116 // cast 1
1117 auto is_cast_1_param = std::make_shared<Var>();
1118 MS_CHECK_TRUE_RET(is_cast_1_param != nullptr, {});
1119 auto is_cast_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
1120 MS_CHECK_TRUE_RET(is_cast_1 != nullptr, {});
1121 auto cast_1 = VectorRef({is_cast_1, add, is_cast_1_param});
1122 // softmax
1123 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1124 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1125 auto softmax = VectorRef({is_softmax, cast_1});
1126 // cast 2
1127 auto is_cast_2_param = std::make_shared<Var>();
1128 MS_CHECK_TRUE_RET(is_cast_2_param != nullptr, {});
1129 auto is_cast_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
1130 MS_CHECK_TRUE_RET(is_cast_2 != nullptr, {});
1131 auto cast_2 = VectorRef({is_cast_2, softmax, is_cast_2_param});
1132 // matmul
1133 auto v_input = std::make_shared<Var>(); // input V
1134 MS_CHECK_TRUE_RET(v_input != nullptr, {});
1135 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1136 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1137 auto matmul_2 = VectorRef({is_matmul_2, cast_2, v_input});
1138 return matmul_2;
1139 }
1140
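// Pattern sketch (as built below): identical to the V1 pattern except that there are no Cast
// nodes around Softmax: BatchMatMul(Q, K) -> Mul(scale) -> Add(mask term) -> Softmax ->
// BatchMatMul(., V).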
1141 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForLLAMAPatternV2() const {
1142 // matmul 1
1143 auto matmul_1_q_input = std::make_shared<Var>(); // input Q
1144 MS_CHECK_TRUE_RET(matmul_1_q_input != nullptr, {});
1145 auto matmul_1_k_input = std::make_shared<Var>(); // input K
1146 MS_CHECK_TRUE_RET(matmul_1_k_input != nullptr, {});
1147 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1148 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1149 auto matmul_1 = VectorRef({is_matmul_1, matmul_1_q_input, matmul_1_k_input});
1150 // mul
1151 auto is_mul_param = std::make_shared<Var>();
1152 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
1153 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1154 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
1155 auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
1156 // ===== attention mask =====
1157 // sub
1158 auto sub_mask_input_1 = std::make_shared<Var>(); // input attention mask
1159 MS_CHECK_TRUE_RET(sub_mask_input_1 != nullptr, {});
1160 // mul
1161 auto is_mask_mul_param = std::make_shared<Var>();
1162 MS_CHECK_TRUE_RET(is_mask_mul_param != nullptr, {});
1163 auto is_mask_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1164 MS_CHECK_TRUE_RET(is_mask_mul != nullptr, {});
1165 auto mask_mul = VectorRef({is_mask_mul, sub_mask_input_1, is_mask_mul_param});
1166 // ===== end attention mask =====
1167 // add
1168 auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1169 MS_CHECK_TRUE_RET(is_add != nullptr, {});
1170 auto add = VectorRef({is_add, mask_mul, mul});
1171 // softmax
1172 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1173 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1174 auto softmax = VectorRef({is_softmax, add});
1175 // matmul
1176 auto v_input = std::make_shared<Var>(); // input V
1177 MS_CHECK_TRUE_RET(v_input != nullptr, {});
1178 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1179 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1180 auto matmul_2 = VectorRef({is_matmul_2, softmax, v_input});
1181 return matmul_2;
1182 }
1183
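// Pattern sketch (as built below): the scaled QK product first goes through an extra Add with a
// free parameter (presumably an alibi / position bias term), then an Add with the mask Mul,
// before Softmax and the second BatchMatMul. Both operands of the mask Mul are free Vars here.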
1184 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForBaiChuan() const {
1185 // matmul 1
1186 auto matmul_1_q_input = std::make_shared<Var>(); // input Q
1187 MS_CHECK_TRUE_RET(matmul_1_q_input != nullptr, {});
1188 auto matmul_1_k_input = std::make_shared<Var>(); // input K
1189 MS_CHECK_TRUE_RET(matmul_1_k_input != nullptr, {});
1190 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1191 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1192 auto matmul_1 = VectorRef({is_matmul_1, matmul_1_q_input, matmul_1_k_input});
1193 // mul
1194 auto is_mul_param = std::make_shared<Var>();
1195 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
1196 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1197 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
1198 auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
1199 // ===== attention mask =====
1200 // mul
1201 auto is_mask_mul_param1 = std::make_shared<Var>();
1202 MS_CHECK_TRUE_RET(is_mask_mul_param1 != nullptr, {});
1203 auto is_mask_mul_param2 = std::make_shared<Var>();
1204 MS_CHECK_TRUE_RET(is_mask_mul_param2 != nullptr, {});
1205 auto is_mask_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1206 MS_CHECK_TRUE_RET(is_mask_mul != nullptr, {});
1207 auto mask_mul = VectorRef({is_mask_mul, is_mask_mul_param1, is_mask_mul_param2});
1208 // ===== end attention mask =====
1209 // add
1210 auto is_add_param = std::make_shared<Var>();
1211 MS_CHECK_TRUE_RET(is_add_param != nullptr, {});
1212 auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1213 MS_CHECK_TRUE_RET(is_add != nullptr, {});
1214 auto add = VectorRef({is_add, mul, is_add_param});
1215 // add for mask
1216 auto is_add_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1217 MS_CHECK_TRUE_RET(is_add_2 != nullptr, {});
1218   auto add_2 = VectorRef({is_add_2, mask_mul, add});
1219 // softmax
1220 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1221 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1222 auto softmax = VectorRef({is_softmax, add_2});
1223 // matmul
1224 auto v_input = std::make_shared<Var>(); // input V
1225 MS_CHECK_TRUE_RET(v_input != nullptr, {});
1226 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1227 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1228 auto matmul_2 = VectorRef({is_matmul_2, softmax, v_input});
1229 return matmul_2;
1230 }
1231
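// Pattern sketch (as built below): Reshape(Q) and Reshape(K) feed a MatMulFusion (the original
// Einsum is rewritten by onnx_einsum_adjust.cc), followed by MulFusion(scale) -> Softmax ->
// MatMulFusion with Reshape(V), then Reshape -> Transpose -> Reshape on the output.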
1232 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDEinsum() const {
1233 // Q reshape
1234 auto input_q_reshape_param_1 = std::make_shared<Var>();
1235 MS_CHECK_TRUE_RET(input_q_reshape_param_1 != nullptr, {});
1236 auto input_q_reshape_param_2 = std::make_shared<Var>();
1237 MS_CHECK_TRUE_RET(input_q_reshape_param_2 != nullptr, {});
1238 auto is_input_q_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1239 MS_CHECK_TRUE_RET(is_input_q_reshape != nullptr, {});
1240 auto q_reshape = VectorRef({is_input_q_reshape, input_q_reshape_param_1, input_q_reshape_param_2});
1241
1242 // K reshape
1243 auto input_k_reshape_param_1 = std::make_shared<Var>();
1244 MS_CHECK_TRUE_RET(input_k_reshape_param_1 != nullptr, {});
1245 auto input_k_reshape_param_2 = std::make_shared<Var>();
1246 MS_CHECK_TRUE_RET(input_k_reshape_param_2 != nullptr, {});
1247 auto is_input_k_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1248 MS_CHECK_TRUE_RET(is_input_k_reshape != nullptr, {});
1249 auto k_reshape = VectorRef({is_input_k_reshape, input_k_reshape_param_1, input_k_reshape_param_2});
1250
1251 // matmul 1, einsum is replaced in onnx_einsum_adjust.cc
1252 auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
1253 MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1254   auto matmul_1 = VectorRef({is_matmul_1, q_reshape, k_reshape});
1255 // mul
1256 auto is_mul_param = std::make_shared<CondVar>(IsParamNode);
1257 MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
1258 auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
1259 MS_CHECK_TRUE_RET(is_mul != nullptr, {});
1260   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
1261 // softmax
1262 auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1263 MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1264 auto softmax = VectorRef({is_softmax, mul});
1265
1266 // V reshape
1267 auto input_v_reshape_param_1 = std::make_shared<Var>();
1268 MS_CHECK_TRUE_RET(input_v_reshape_param_1 != nullptr, {});
1269 auto input_v_reshape_param_2 = std::make_shared<Var>();
1270 MS_CHECK_TRUE_RET(input_v_reshape_param_2 != nullptr, {});
1271 auto is_input_v_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1272 MS_CHECK_TRUE_RET(is_input_v_reshape != nullptr, {});
1273 auto v_reshape = VectorRef({is_input_v_reshape, input_v_reshape_param_1, input_v_reshape_param_2});
1274
1275 // matmul 2
1276 auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
1277 MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1278 auto matmul_2 = VectorRef({is_matmul_2, softmax, v_reshape});
1279 // output reshape to four dims
1280 auto reshape_o_2 = std::make_shared<Var>();
1281 MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
1282 auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1283 MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
1284 auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
1285 // output transpose
1286 auto is_transpose_o_param = std::make_shared<CondVar>(IsParamNode);
1287 MS_CHECK_TRUE_RET(is_transpose_o_param != nullptr, {});
1288 auto is_transpose_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
1289 MS_CHECK_TRUE_RET(is_transpose_o != nullptr, {});
1290 auto transpose_o = VectorRef({is_transpose_o, reshape_o, is_transpose_o_param});
1291 // output reshape to three dims
1292 auto reshape_o2_2 = std::make_shared<Var>();
1293 MS_CHECK_TRUE_RET(reshape_o2_2 != nullptr, {});
1294 auto is_reshape_o2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1295 MS_CHECK_TRUE_RET(is_reshape_o2 != nullptr, {});
1296 auto reshape_o2 = VectorRef({is_reshape_o2, transpose_o, reshape_o2_2});
1297 return reshape_o2;
1298 }
1299
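// Builds a PromptFlashAttention CNode in BNSD layout from Q/K/V (and an optional attention
// mask), attaching the attributes computed by the callers. As a purely illustrative example
// (values not taken from any real model): for Q/K/V of shape [B, N, S, D] = [2, 8, 1024, 64]
// a caller would pass num_heads = 8, num_key_value_heads = 8, scale_value = 1 / sqrt(64) = 0.125,
// and next_token = 0 for the masked prompt case or kNumMaxNextTokenSize when no mask is fused.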
1300 CNodePtr FlashAttentionFusion::CreatePromptFlashAttentionCnodeForBNSD(
1301 const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
1302 const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token, float scale_value,
1303 const std::shared_ptr<FlashAttentionParm> &fa_parm, int64_t num_key_value_heads) const {
1304 MS_LOG(INFO) << "CreatePromptFlashAttentionCnodeForBNSD";
1305 MS_LOG(INFO) << "num heads: " << num_heads << ", input layout: BNSD, next tokens: " << next_token
1306 << ", scale value: " << scale_value << ", num_key_value_heads: " << num_key_value_heads;
1307 MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << ", k name: " << k->fullname_with_scope()
1308 << ", v name: " << v->fullname_with_scope();
1309 if (num_heads < 0 || scale_value < 0 || next_token < 0 || num_key_value_heads < 0) {
1310     MS_LOG(WARNING) << "invalid param: num_heads, scale_value, next_token or num_key_value_heads is negative.";
1311 return nullptr;
1312 }
1313 if (fa_parm == nullptr) {
1314 MS_LOG(WARNING) << "FA parameter is null, please check.";
1315 return nullptr;
1316 }
1317 // create op
1318 auto prompt_flash_attention_prim = std::make_shared<ops::PromptFlashAttention>();
1319 if (prompt_flash_attention_prim == nullptr) {
1320 MS_LOG(ERROR) << "new prompt flash attention prim failed.";
1321 return nullptr;
1322 }
1323 // add attr
1324 prompt_flash_attention_prim->AddAttr("num_heads", api::MakeValue(num_heads));
1325 prompt_flash_attention_prim->AddAttr("input_layout", api::MakeValue("BNSD"));
1326 prompt_flash_attention_prim->AddAttr("next_tokens", api::MakeValue(next_token));
1327 prompt_flash_attention_prim->AddAttr("scale_value", api::MakeValue(scale_value));
1328 prompt_flash_attention_prim->AddAttr("num_key_value_heads", api::MakeValue(num_key_value_heads));
1329 prompt_flash_attention_prim->AddAttr("inner_precise", api::MakeValue(fa_parm->inner_precise));
1330 prompt_flash_attention_prim->AddAttr("sparse_mode", api::MakeValue(fa_parm->sparse_mode));
1331
1332 auto fa_prim_c = prompt_flash_attention_prim->GetPrim();
1333 if (fa_prim_c == nullptr) {
1334 MS_LOG(ERROR) << "fa_prim_c is nullptr.";
1335 return nullptr;
1336 }
1337 CNodePtr prompt_flash_attention_cnode = nullptr;
1338 if (atten_mask != nullptr) {
1339 prompt_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v, atten_mask});
1340 } else {
1341 prompt_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v});
1342 }
1343 if (prompt_flash_attention_cnode == nullptr) {
1344 MS_LOG(ERROR) << "new cnode failed.";
1345 return nullptr;
1346 }
1347 prompt_flash_attention_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_prompt_flash_attention_bnsd");
1348 if (node->abstract() != nullptr) {
1349 prompt_flash_attention_cnode->set_abstract(node->abstract()->Clone());
1350 }
1351 MS_LOG(INFO) << "create PromptFlashAttention success.";
1352 return prompt_flash_attention_cnode;
1353 }
1354
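// Same as CreatePromptFlashAttentionCnodeForBNSD, but additionally wires a pse (position bias /
// shift) tensor. None value nodes are inserted for the unused optional inputs so that pse lands
// on the input index PromptFlashAttention expects.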
1355 CNodePtr FlashAttentionFusion::CreatePromptFlashAttentionCnodeForBNSDWithPse(
1356 const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
1357 const AnfNodePtr &atten_mask, const AnfNodePtr &pse, int64_t num_heads, int64_t next_token, float scale_value,
1358 const std::shared_ptr<FlashAttentionParm> &fa_parm, int64_t num_key_value_heads) const {
1359 MS_LOG(INFO) << "CreatePromptFlashAttentionCnodeForBNSD with pse";
1360 MS_LOG(INFO) << "num heads: " << num_heads << ", input layout: BNSD, next tokens: " << next_token
1361 << ", scale value: " << scale_value << ", num_key_value_heads: " << num_key_value_heads;
1362 MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << ", k name: " << k->fullname_with_scope()
1363 << ", v name: " << v->fullname_with_scope() << ", pse name " << pse->fullname_with_scope();
1364 if (num_heads < 0 || scale_value < 0 || next_token < 0 || num_key_value_heads < 0) {
1365     MS_LOG(WARNING) << "invalid param: num_heads, scale_value, next_token or num_key_value_heads is negative.";
1366 return nullptr;
1367 }
1368 if (fa_parm == nullptr) {
1369 MS_LOG(WARNING) << "FA parameter is null, please check";
1370 return nullptr;
1371 }
1372 // create op
1373 auto prompt_flash_attention_prim = std::make_shared<ops::PromptFlashAttention>();
1374 if (prompt_flash_attention_prim == nullptr) {
1375 MS_LOG(ERROR) << "new prompt flash attention prim failed.";
1376 return nullptr;
1377 }
1378 // add attr
1379 prompt_flash_attention_prim->AddAttr("num_heads", api::MakeValue(num_heads));
1380 prompt_flash_attention_prim->AddAttr("input_layout", api::MakeValue("BNSD"));
1381 prompt_flash_attention_prim->AddAttr("next_tokens", api::MakeValue(next_token));
1382 prompt_flash_attention_prim->AddAttr("scale_value", api::MakeValue(scale_value));
1383 prompt_flash_attention_prim->AddAttr("num_key_value_heads", api::MakeValue(num_key_value_heads));
1384 prompt_flash_attention_prim->AddAttr("inner_precise", api::MakeValue(fa_parm->inner_precise));
1385 prompt_flash_attention_prim->AddAttr("sparse_mode", api::MakeValue(fa_parm->sparse_mode));
1386
1387 auto fa_prim_c = prompt_flash_attention_prim->GetPrim();
1388 if (fa_prim_c == nullptr) {
1389 MS_LOG(ERROR) << "fa_prim_c is nullptr.";
1390 return nullptr;
1391 }
1392 CNodePtr prompt_flash_attention_cnode = nullptr;
1393 auto none_value_node_1 = NewValueNode(std::make_shared<None>());
1394 none_value_node_1->set_abstract(std::make_shared<abstract::AbstractNone>());
1395 auto none_value_node_2 = NewValueNode(std::make_shared<None>());
1396 none_value_node_2->set_abstract(std::make_shared<abstract::AbstractNone>());
1397
1398 if (atten_mask != nullptr) {
1399 prompt_flash_attention_cnode =
1400 func_graph->NewCNode(fa_prim_c, {q, k, v, atten_mask, none_value_node_1, none_value_node_2, pse});
1401 } else {
1402 auto none_value_node = NewValueNode(std::make_shared<None>());
1403 none_value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
1404 prompt_flash_attention_cnode =
1405 func_graph->NewCNode(fa_prim_c, {q, k, v, none_value_node, none_value_node_1, none_value_node_2, pse});
1406 }
1407 if (prompt_flash_attention_cnode == nullptr) {
1408 MS_LOG(ERROR) << "new prompt_flash_attention_cnode failed.";
1409 return nullptr;
1410 }
1411 prompt_flash_attention_cnode->set_fullname_with_scope(node->fullname_with_scope() +
1412 "_prompt_flash_attention_bnsd_pse");
1413 if (node->abstract() != nullptr) {
1414 prompt_flash_attention_cnode->set_abstract(node->abstract()->Clone());
1415 }
1416 MS_LOG(INFO) << "create PromptFlashAttention success.";
1417 return prompt_flash_attention_cnode;
1418 }
1419
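// BSH-layout variant: the per-head dimension is folded into H, so num_key_value_heads is simply
// set equal to num_heads (no GQA handling on this path).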
1420 CNodePtr FlashAttentionFusion::CreatePromptFlashAttentionCnodeForBSH(
1421 const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
1422 const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token, float scale_value,
1423 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1424 MS_LOG(INFO) << "CreatePromptFlashAttentionCnodeForBSH";
1425 MS_LOG(INFO) << "input Q name: " << q->fullname_with_scope() << " ,input K name: " << k->fullname_with_scope()
1426 << " ,input V name: " << v->fullname_with_scope();
1427 // create op
1428 auto prompt_flash_attention_prim = std::make_shared<ops::PromptFlashAttention>();
1429 if (prompt_flash_attention_prim == nullptr) {
1430     MS_LOG(ERROR) << "prompt_flash_attention_prim is nullptr.";
1431 return nullptr;
1432 }
1433 if (fa_parm == nullptr) {
1434 MS_LOG(WARNING) << "FA parameter is null, please check";
1435 return nullptr;
1436 }
1437 // add attr
1438 prompt_flash_attention_prim->AddAttr("num_heads", api::MakeValue(num_heads));
1439 prompt_flash_attention_prim->AddAttr("input_layout", api::MakeValue("BSH"));
1440 prompt_flash_attention_prim->AddAttr("next_tokens", api::MakeValue(next_token));
1441 prompt_flash_attention_prim->AddAttr("scale_value", api::MakeValue(scale_value));
1442 prompt_flash_attention_prim->AddAttr("num_key_value_heads", api::MakeValue(num_heads));
1443 prompt_flash_attention_prim->AddAttr("inner_precise", api::MakeValue(fa_parm->inner_precise));
1444 prompt_flash_attention_prim->AddAttr("sparse_mode", api::MakeValue(fa_parm->sparse_mode));
1445
1446 MS_LOG(INFO) << "num heads: " << num_heads << ", input layout: BSH, next tokens: " << next_token
1447 << ", scale value: " << scale_value;
1448 auto fa_prim_c = prompt_flash_attention_prim->GetPrim();
1449 CNodePtr prompt_flash_attention_cnode = nullptr;
1450 if (atten_mask != nullptr) {
1451 prompt_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v, atten_mask});
1452 } else {
1453 prompt_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v});
1454 }
1455 if (prompt_flash_attention_cnode == nullptr) {
1456 MS_LOG(ERROR) << "new cnode failed.";
1457 return nullptr;
1458 }
1459 prompt_flash_attention_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_prompt_flash_attention_bsh");
1460 if (node->abstract() != nullptr) {
1461 prompt_flash_attention_cnode->set_abstract(node->abstract()->Clone());
1462 }
1463 MS_LOG(INFO) << "create PromptFlashAttention success.";
1464 return prompt_flash_attention_cnode;
1465 }
1466
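// Q and K come from the QK BatchMatMul, V from the second BatchMatMul, and the mask from the
// mask Mul; all shapes must be static 4-D. scale_value = 1 / sqrt(D) with D = Q.shape[3], and
// the sequence length S = Q.shape[2] selects the op: PromptFlashAttention for S != 1,
// IncreFlashAttention for the single-token decoding case.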
1467 CNodePtr FlashAttentionFusion::CreateFAForBNSDWithAttenMask(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
1468 const CNodePtr &qk_matmul, const CNodePtr &v_matmul,
1469 const CNodePtr &attention_mask_mul,
1470 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1471 auto q = qk_matmul->input(kNumIndex1);
1472 MS_CHECK_TRUE_RET(q != nullptr, nullptr);
1473 auto k = qk_matmul->input(kNumIndex2);
1474 MS_CHECK_TRUE_RET(k != nullptr, nullptr);
1475 auto v = v_matmul->input(kNumIndex2);
1476 MS_CHECK_TRUE_RET(v != nullptr, nullptr);
1477 auto atten_mask = attention_mask_mul->input(1)->cast<CNodePtr>();
1478 MS_CHECK_TRUE_RET(atten_mask != nullptr, nullptr);
1479
1480 auto input_tensor_q_shape = GetTensorShape(qk_matmul, kNumIndex1);
1481 if (input_tensor_q_shape.size() != kNumDimSize4) {
1482 MS_LOG(ERROR) << "q shape is not 4 dims";
1483 return nullptr;
1484 }
1485 auto input_tensor_k_shape = GetTensorShape(qk_matmul, kNumIndex2);
1486 if (input_tensor_k_shape.size() != kNumDimSize4) {
1487 MS_LOG(ERROR) << "k shape is not 4 dims";
1488 return nullptr;
1489 }
1490 auto input_tensor_v_shape = GetTensorShape(v_matmul, kNumIndex2);
1491 if (input_tensor_v_shape.size() != kNumDimSize4) {
1492 MS_LOG(ERROR) << "v shape is not 4 dims";
1493 return nullptr;
1494 }
1495 auto atten_mask_input_shape = GetTensorShape(attention_mask_mul, 1);
1496   if (atten_mask_input_shape.size() != kNumDimSize4) {
1497     MS_LOG(ERROR) << "atten mask shape is not 4 dims";
1498 return nullptr;
1499 }
1500 MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << " , k name: " << k->fullname_with_scope()
1501 << " , v name: " << v->fullname_with_scope()
1502 << ", atten mask name: " << atten_mask->fullname_with_scope();
1503 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << ", k shape: " << input_tensor_k_shape
1504                << ", v shape: " << input_tensor_v_shape << ", atten mask shape: " << atten_mask_input_shape;
1505 // check input shape
1506 if (input_tensor_q_shape[kNumIndex3] <= 0 || input_tensor_q_shape[kNumIndex1] <= 0) {
1507 MS_LOG(ERROR) << "D is -1";
1508 return nullptr;
1509 }
1510 float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1511 int64_t seq_len = input_tensor_q_shape[kNumIndex2];
1512 int64_t num_key_value_heads = input_tensor_k_shape[1];
1513 if (seq_len != 1) {
1514 return CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
1515 input_tensor_q_shape[kNumIndex1], 0, scale_value, fa_parm,
1516 num_key_value_heads);
1517 } else {
1518 MS_LOG(INFO) << "seq len is 1, incre flash attention.";
1519 return CreateIncreFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
1520 input_tensor_q_shape[kNumIndex1], scale_value, num_key_value_heads);
1521 }
1522 return nullptr;
1523 }
1524
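// GQA case: the K/V feeding the BatchMatMuls are Reshape(Tile(ExpandDims(...))) chains that
// broadcast a smaller number of key/value heads up to num_heads. The fusion walks back through
// Reshape -> Tile -> ExpandDims to the original K/V and lets the FA op handle the broadcast via
// num_key_value_heads (taken from K.shape[1]).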
1525 CNodePtr FlashAttentionFusion::CreateGQACNodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
1526 const CNodePtr &qk_matmul, const CNodePtr &v_matmul,
1527 const CNodePtr &attention_mask_mul,
1528 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1529 auto q = qk_matmul->input(kNumIndex1);
1530 MS_CHECK_TRUE_RET(q != nullptr, nullptr);
1531
1532   auto k_reshape = qk_matmul->input(kNumIndex2)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
1533   MS_LOG(INFO) << k_reshape->fullname_with_scope();
1534   auto k_tile = k_reshape->input(kNumIndex1)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(k_tile != nullptr, nullptr);
1535   MS_LOG(INFO) << k_tile->fullname_with_scope();
1536   auto k_expend_dim = k_tile->input(kNumIndex1)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(k_expend_dim != nullptr, nullptr);
1537
1538   auto v_reshape = v_matmul->input(kNumIndex2)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
1539   MS_LOG(INFO) << v_reshape->fullname_with_scope();
1540   auto v_tile = v_reshape->input(kNumIndex1)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(v_tile != nullptr, nullptr);
1541   MS_LOG(INFO) << v_tile->fullname_with_scope();
1542   auto v_expend_dim = v_tile->input(kNumIndex1)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(v_expend_dim != nullptr, nullptr);
1543
1544 auto k = k_expend_dim->input(kNumIndex1);
1545 MS_CHECK_TRUE_RET(k != nullptr, nullptr);
1546 auto v = v_expend_dim->input(kNumIndex1);
1547 MS_CHECK_TRUE_RET(v != nullptr, nullptr);
1548
1549 auto atten_mask = attention_mask_mul->input(kNumIndex1)->cast<CNodePtr>();
1550 MS_CHECK_TRUE_RET(atten_mask != nullptr, nullptr);
1551
1552 auto input_tensor_q_shape = GetTensorShape(qk_matmul, kNumIndex1);
1553 if (input_tensor_q_shape.size() != kNumDimSize4) {
1554 MS_LOG(ERROR) << "q shape is not 4 dims";
1555 return nullptr;
1556 }
1557 auto input_tensor_k_shape = GetTensorShape(k_expend_dim, kNumIndex1);
1558 if (input_tensor_k_shape.size() != kNumDimSize4) {
1559 MS_LOG(ERROR) << "k shape is not 4 dims";
1560 return nullptr;
1561 }
1562 auto input_tensor_v_shape = GetTensorShape(v_expend_dim, kNumIndex1);
1563 if (input_tensor_v_shape.size() != kNumDimSize4) {
1564 MS_LOG(ERROR) << "v shape is not 4 dims";
1565 return nullptr;
1566 }
1567 auto atten_mask_input_shape = GetTensorShape(attention_mask_mul, 1);
1568   if (atten_mask_input_shape.size() != kNumDimSize4) {
1569     MS_LOG(ERROR) << "atten mask shape is not 4 dims";
1570 return nullptr;
1571 }
1572 MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << " , k name: " << k->fullname_with_scope()
1573 << " , v name: " << v->fullname_with_scope()
1574 << ", atten mask name: " << atten_mask->fullname_with_scope();
1575 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << ", k shape: " << input_tensor_k_shape
1576 << ", v shape: " << input_tensor_v_shape << ", atten mask shape: " << atten_mask_input_shape;
1577 // check input shape
1578 if (input_tensor_q_shape[kNumIndex3] <= 0 || input_tensor_q_shape[kNumIndex1] <= 0) {
1579 MS_LOG(ERROR) << "D is -1";
1580 return nullptr;
1581 }
1582 float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1583 int64_t seq_len = input_tensor_q_shape[kNumIndex2];
1584 int64_t num_key_value_heads = input_tensor_k_shape[1];
1585 if (seq_len != 1) {
1586 return CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
1587 input_tensor_q_shape[kNumIndex1], 0, scale_value, fa_parm,
1588 num_key_value_heads);
1589 } else {
1590 MS_LOG(INFO) << "seq len is 1, incre flash attention.";
1591 return CreateIncreFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
1592 input_tensor_q_shape[kNumIndex1], scale_value, num_key_value_heads);
1593 }
1594 return nullptr;
1595 }
1596
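// Single-token (S == 1) decoding path: builds an IncreFlashAttention CNode in BNSD layout. The
// dyn_input_sizes attribute below presumably tells the kernel how the optional inputs are
// grouped; only q, k, v and (if present) the attention mask are actually connected here.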
1597 CNodePtr FlashAttentionFusion::CreateIncreFlashAttentionCnodeForBNSD(
1598 const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
1599 const AnfNodePtr &atten_mask, int64_t num_heads, float scale_value, int64_t num_key_value_heads) const {
1600 MS_LOG(INFO) << "CreateIncreFlashAttentionCnodeForBNSD";
1601 // create op
1602 auto incre_flash_attention_prim = std::make_shared<ops::IncreFlashAttention>();
1603 if (incre_flash_attention_prim == nullptr) {
1604 MS_LOG(ERROR) << "incre_flash_attention_prim is nullptr.";
1605 return nullptr;
1606 }
1607 // add attr
1608 incre_flash_attention_prim->AddAttr("num_heads", api::MakeValue(num_heads));
1609 incre_flash_attention_prim->AddAttr("input_layout", api::MakeValue("BNSD"));
1610 incre_flash_attention_prim->AddAttr("scale_value", api::MakeValue(scale_value));
1611 incre_flash_attention_prim->AddAttr("num_key_value_heads", api::MakeValue(num_key_value_heads));
1612
1613 std::vector<int64_t> dyn_input_sizes = {-1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1};
1614 incre_flash_attention_prim->AddAttr("dyn_input_sizes", api::MakeValue(dyn_input_sizes));
1615
1616 MS_LOG(INFO) << "num heads: " << num_heads << ", input layout: BNSD, scale value: " << scale_value
1617 << ", num_key_value_heads: " << num_key_value_heads << ", dyn_input_sizes:" << dyn_input_sizes;
1618 auto fa_prim_c = incre_flash_attention_prim->GetPrim();
1619 CNodePtr incre_flash_attention_cnode = nullptr;
1620 if (atten_mask != nullptr) {
1621 incre_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v, atten_mask});
1622 } else {
1623 incre_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v});
1624 }
1625 if (incre_flash_attention_cnode == nullptr) {
1626 MS_LOG(ERROR) << "new cnode failed.";
1627 return nullptr;
1628 }
1629 incre_flash_attention_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_incre_flash_attention");
1630 if (node->abstract() != nullptr) {
1631 incre_flash_attention_cnode->set_abstract(node->abstract()->Clone());
1632 }
1633 MS_LOG(INFO) << "create IncreFlashAttention success.";
1634 return incre_flash_attention_cnode;
1635 }
1636
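// SD 1.5 path for head dim D = kNumDValue (40), which is not a multiple of 16: Q/K/V are padded
// by kNumPadSize (8) on the last axis (presumably to reach 48), PromptFlashAttention runs on the
// padded tensors, and a Slice afterwards trims the output back to D = 40.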
1637 CNodePtr FlashAttentionFusion::CreateFAForSD15(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
1638 const AnfNodePtr &q_trans, const AnfNodePtr &k_trans,
1639 const AnfNodePtr &v_trans, int64_t num_head, int64_t next_token,
1640 float scale_value,
1641 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1642 MS_LOG(INFO) << "create flash attention for stable diffusion V1.5.";
1643 auto q_pad_node = CreatePadCNode(func_graph, q_trans, kNumPadSize, node->fullname_with_scope());
1644 if (q_pad_node == nullptr) {
1645 MS_LOG(WARNING) << "create q_pad_node failed.";
1646 return nullptr;
1647 }
1648 auto k_pad_node = CreatePadCNode(func_graph, k_trans, kNumPadSize, node->fullname_with_scope());
1649 if (k_pad_node == nullptr) {
1650     MS_LOG(WARNING) << "create k_pad_node failed.";
1651 return nullptr;
1652 }
1653 auto v_pad_node = CreatePadCNode(func_graph, v_trans, kNumPadSize, node->fullname_with_scope());
1654 if (v_pad_node == nullptr) {
1655     MS_LOG(WARNING) << "create v_pad_node failed.";
1656 return nullptr;
1657 }
1658 auto fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_pad_node, k_pad_node, v_pad_node, nullptr,
1659 num_head, next_token, scale_value, fa_parm, num_head);
1660 if (fa_node == nullptr) {
1661 MS_LOG(WARNING) << "create fa_node failed.";
1662 return nullptr;
1663 }
1664 auto slice_node = CreateSliceCNode(func_graph, fa_node, kNumDValue);
1665 return slice_node;
1666 }
1667
1668 CNodePtr FlashAttentionFusion::CreateFAWithPadAndPse(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
1669 const AnfNodePtr &q_trans, const AnfNodePtr &k_trans,
1670 const AnfNodePtr &v_trans, const AnfNodePtr &pse, int64_t num_head,
1671 int64_t next_token, float scale_value,
1672 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1673 MS_LOG(INFO) << "create flash attention for stable diffusion with pse input.";
1674 auto q_pad_node = CreatePadCNode(func_graph, q_trans, kNumPadSize);
1675 if (q_pad_node == nullptr) {
1676 MS_LOG(WARNING) << "create q_pad_node failed.";
1677 return nullptr;
1678 }
1679 auto k_pad_node = CreatePadCNode(func_graph, k_trans, kNumPadSize);
1680 if (k_pad_node == nullptr) {
1681     MS_LOG(WARNING) << "create k_pad_node failed.";
1682 return nullptr;
1683 }
1684 auto v_pad_node = CreatePadCNode(func_graph, v_trans, kNumPadSize);
1685 if (v_pad_node == nullptr) {
1686     MS_LOG(WARNING) << "create v_pad_node failed.";
1687 return nullptr;
1688 }
1689 auto fa_node =
1690 CreatePromptFlashAttentionCnodeForBNSDWithPse(func_graph, node, q_pad_node, k_pad_node, v_pad_node, nullptr, pse,
1691 num_head, next_token, scale_value, fa_parm, num_head);
1692 if (fa_node == nullptr) {
1693 MS_LOG(WARNING) << "create fa_node failed.";
1694 return nullptr;
1695 }
1696
1697 auto slice_node = CreateSliceCNode(func_graph, fa_node, kNumDValue);
1698 return slice_node;
1699 }
1700
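// For dynamic shapes the scale cannot be read from the Q shape, so it is taken from the constant
// operand of the scaling Mul (a ValueNode or Parameter holding a single fp32/fp16 element).
// Callers then recover the head dim as d_value = 1 / scale_value^2, since scale = 1 / sqrt(D).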
1701 float FlashAttentionFusion::GetScaleValueForDynamicShape(const AnfNodePtr &mul_const_input) const {
1702 tensor::TensorPtr tensor_info = nullptr;
1703 if (utils::isa<ValueNodePtr>(mul_const_input)) {
1704 auto value_node = mul_const_input->cast<ValueNodePtr>();
1705 if (value_node == nullptr) {
1706 MS_LOG(WARNING) << "value_node is nullptr.";
1707 return -1;
1708 }
1709 auto value = value_node->value();
1710 if (value == nullptr) {
1711 MS_LOG(WARNING) << "value is nullptr.";
1712 return -1;
1713 }
1714 tensor_info = value->cast<tensor::TensorPtr>();
1715 } else if (utils::isa<ParameterPtr>(mul_const_input)) {
1716 // for dynamic shape: get scale value
1717 auto mul_param = mul_const_input->cast<ParameterPtr>()->default_param();
1718 if (mul_param == nullptr) {
1719 MS_LOG(WARNING) << "mul_param is nullptr.";
1720 return -1;
1721 }
1722 tensor_info = mul_param->cast<tensor::TensorPtr>();
1723 } else {
1724 MS_LOG(WARNING) << "mul input is not ParameterPtr or ValueNodePtr.";
1725 return -1;
1726 }
1727 if (tensor_info == nullptr) {
1728 MS_LOG(WARNING) << "tensor info is nullptr.";
1729 return -1;
1730 }
1731 if (tensor_info->data_c() == nullptr) {
1732 MS_LOG(WARNING) << "mul data is nullptr.";
1733 return -1;
1734 }
1735 if (tensor_info->ElementsNum() != 1) {
1736 MS_LOG(WARNING) << "mul value elements num is not 1, ElementsNum is: " << tensor_info->ElementsNum();
1737 return -1;
1738 }
1739 if (tensor_info->data_type() == kNumberTypeFloat32) {
1740 return static_cast<float *>(tensor_info->data_c())[0];
1741 } else if (tensor_info->data_type() == kNumberTypeFloat16) {
1742 return static_cast<float>(static_cast<float16 *>(tensor_info->data_c())[0]);
1743 } else {
1744     MS_LOG(ERROR) << "unsupported data type: " << tensor_info->data_type();
1745 return -1;
1746 }
1747 return -1;
1748 }
1749
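// SDXL: walks matmul_2 -> softmax -> div -> matmul_1 and pulls Q, the transposed K and V off the
// two BatchMatMuls. Shapes must be static 4-D; scale_value = 1 / sqrt(D), num_head = Q.shape[1],
// and next_tokens is set to kNumMaxNextTokenSize since no attention mask is fused on this path.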
1750 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForMsSDXL(
1751 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
1752 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1753 MS_LOG(INFO) << "flash attention for SDXL";
1754 MS_CHECK_TRUE_RET(node != nullptr, nullptr);
1755 auto matmul_2 = node->cast<CNodePtr>();
1756 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
1757 auto softmax = matmul_2->input(1)->cast<CNodePtr>();
1758 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
1759 auto div = softmax->input(1)->cast<CNodePtr>();
1760 MS_CHECK_TRUE_RET(div != nullptr, nullptr);
1761 auto matmul_1 = div->input(1)->cast<CNodePtr>();
1762 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
1763
1764 auto q = matmul_1->input(1)->cast<CNodePtr>();
1765 MS_CHECK_TRUE_RET(q != nullptr, nullptr);
1766 auto k_trans = matmul_1->input(2)->cast<CNodePtr>();
1767 MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
1768 auto k = k_trans->input(1)->cast<CNodePtr>();
1769 MS_CHECK_TRUE_RET(k != nullptr, nullptr);
1770 auto v = matmul_2->input(2)->cast<CNodePtr>();
1771 MS_CHECK_TRUE_RET(v != nullptr, nullptr);
1772
1773 auto input_tensor_q_shape = GetTensorShape(matmul_1, kNumIndex1);
1774 auto input_tensor_k_shape = GetTensorShape(k_trans, kNumIndex1);
1775 auto input_tensor_v_shape = GetTensorShape(matmul_2, kNumIndex2);
1776
1777 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
1778 << " , v shape: " << input_tensor_v_shape;
1779 if (input_tensor_q_shape.size() != kNumShapeSize4 || input_tensor_k_shape.size() != kNumShapeSize4 ||
1780 input_tensor_v_shape.size() != kNumShapeSize4) {
1781 MS_LOG(WARNING) << "input shape is not 4 dims";
1782 return nullptr;
1783 }
1784
1785 int64_t next_tokens = kNumMaxNextTokenSize;
1786 float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1787 int64_t num_head = input_tensor_q_shape[kNumIndex1];
1788
1789 if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
1790 fa_parm->seq_threshold)) {
1791 MS_LOG(INFO) << "shape check failed.";
1792 return nullptr;
1793 }
1794 auto fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, nullptr, num_head, next_tokens,
1795 scale_value, fa_parm, num_head);
1796 MS_CHECK_TRUE_MSG(fa_node != nullptr, nullptr, "create FA failed, fa_node is nullptr.");
1797   auto manager = Manage(func_graph);
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
1798 (void)manager->Replace(matmul_2, fa_node);
1799 MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
1800 return nullptr;
1801 }
1802
1803 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForMsSDPseShift(
1804 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
1805 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1806 MS_LOG(INFO) << "flash attention for SD pse shift";
1807 MS_CHECK_TRUE_RET(node != nullptr, nullptr);
1808 // reshape
1809 auto reshape_output = node->cast<CNodePtr>();
1810 MS_CHECK_TRUE_RET(reshape_output != nullptr, nullptr);
1811 // matmul
1812 auto matmul_2 = reshape_output->input(kNumIndex1)->cast<CNodePtr>();
1813 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
1814 // cast
1815 auto cast = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
1816 MS_CHECK_TRUE_RET(cast != nullptr, nullptr);
1817 // softmax
1818 auto softmax = cast->input(kNumIndex1)->cast<CNodePtr>();
1819 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
1820 // add
1821 auto add = softmax->input(kNumIndex1)->cast<CNodePtr>();
1822 MS_CHECK_TRUE_RET(add != nullptr, nullptr);
1823 // mul
1824 auto mul = add->input(kNumIndex1)->cast<CNodePtr>();
1825 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
1826 // matmul
1827 auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
1828 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
1829
1830 // Q
1831 auto q_reshape = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
1832 MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
1833 auto q_input = q_reshape->input(kNumIndex1)->cast<CNodePtr>();
1834 MS_CHECK_TRUE_RET(q_input != nullptr, nullptr);
1835
1836 // K
1837 auto trans = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
1838 MS_CHECK_TRUE_RET(trans != nullptr, nullptr);
1839 auto k_reshape = trans->input(kNumIndex1)->cast<CNodePtr>();
1840 MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
1841 auto k_input = k_reshape->input(kNumIndex1)->cast<CNodePtr>();
1842 MS_CHECK_TRUE_RET(k_input != nullptr, nullptr);
1843
1844 // V
1845 auto v_reshape = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
1846   MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
1847 auto v_input = v_reshape->input(kNumIndex1)->cast<CNodePtr>();
1848 MS_CHECK_TRUE_RET(v_input != nullptr, nullptr);
1849
1850 auto input_tensor_q_shape = GetTensorShape(q_reshape, kNumIndex1);
1851 auto input_tensor_k_shape = GetTensorShape(k_reshape, kNumIndex1);
1852 auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
1853
1854   if (input_tensor_q_shape.size() != kNumShapeSize4 || input_tensor_k_shape.size() != kNumShapeSize4 ||
1855       input_tensor_v_shape.size() != kNumShapeSize4) {
1856     MS_LOG(WARNING) << "Dynamic shape is not supported, please check the Q/K/V input tensor shapes, Q shape: "
1857                     << input_tensor_q_shape;
1858     return nullptr;
1859   }
1866
1867 float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1868 int64_t num_head = input_tensor_q_shape[kNumIndex1];
1869 int64_t next_tokens = kNumMaxNextTokenSize;
1870 int64_t d_value = input_tensor_q_shape[kNumIndex3];
1871
1872 // reshape add input2 from (B*N)SS to BNSS
1873 auto manager = Manage(func_graph);
1874 MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
1875
1876 std::vector<int32_t> BNSS_shape = {
1877 static_cast<int32_t>(input_tensor_q_shape[0]), static_cast<int32_t>(input_tensor_q_shape[1]),
1878 static_cast<int32_t>(input_tensor_q_shape[2]), static_cast<int32_t>(input_tensor_k_shape[2])};
1879 auto shape_parm_node = BuildIntVecParameterNode(func_graph, BNSS_shape, node->fullname_with_scope() + "_shape_perm");
1880 MS_CHECK_TRUE_MSG(shape_parm_node != nullptr, nullptr, "create shape_parm_node return nullptr");
1881
1882 std::vector<AnfNodePtr> op_inputs;
1883 auto add_input2 = add->input(kNumIndex2);
1884 if (utils::isa<CNodePtr>(add_input2)) {
1885 auto pse_BN_S_S_node = add_input2->cast<CNodePtr>();
1886 op_inputs = {pse_BN_S_S_node, shape_parm_node};
1887 } else if (utils::isa<ParameterPtr>(add_input2)) {
1888 auto pse_BN_S_S_parm = add_input2->cast<ParameterPtr>();
1889 op_inputs = {pse_BN_S_S_parm, shape_parm_node};
1890 } else {
1891 return nullptr;
1892 }
1893
1894 auto reshape_prim = std::make_shared<ops::Reshape>();
1895 MS_CHECK_TRUE_MSG(reshape_prim != nullptr, nullptr, "create reshape_prim return nullptr");
1896 auto reshape_prim_c = reshape_prim->GetPrim();
1897 MS_CHECK_TRUE_MSG(reshape_prim_c != nullptr, nullptr, "create prim_c return nullptr");
1898 auto pse_BNSS_node = func_graph->NewCNode(reshape_prim_c, op_inputs);
1899 MS_CHECK_TRUE_MSG(pse_BNSS_node != nullptr, nullptr, "create pse_BNSS_node return nullptr");
1900 pse_BNSS_node->set_fullname_with_scope(node->fullname_with_scope() + "_pse_reshape");
1901 if (node->abstract() != nullptr) {
1902 pse_BNSS_node->set_abstract(node->abstract()->Clone());
1903 }
1904
1905 // FA fusion
1906 CNodePtr fa_node = nullptr;
1907 if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
1908 fa_parm->seq_threshold)) {
1909 return nullptr;
1910 }
1911 if (d_value == kNumDValue) {
1912 fa_node = CreateFAWithPadAndPse(func_graph, node, q_input, k_input, v_input, pse_BNSS_node, num_head, next_tokens,
1913 scale_value, fa_parm);
1914 } else {
1915 fa_node =
1916 CreatePromptFlashAttentionCnodeForBNSDWithPse(func_graph, node, q_input, k_input, v_input, nullptr, pse_BNSS_node,
1917 num_head, next_tokens, scale_value, fa_parm, num_head);
1918 }
1919 MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
1920 (void)manager->Replace(node, fa_node);
1921 MS_LOG(INFO) << "create prompt flash attention success for stable diffusion pse shift.";
1922 return nullptr;
1923 }
1924
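// SD 2.1: walks matmul_2 -> softmax -> mul -> matmul_1 and recovers Q/K/V through their
// Reshape(Transpose(...)) producers; shapes must be static 4-D and the transposed tensors are
// passed to PromptFlashAttention directly (BNSD layout, no attention mask).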
1925 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForMsSD21(
1926 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
1927 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1928 MS_LOG(INFO) << "flash attention for SD21";
1929 MS_CHECK_TRUE_RET(node != nullptr, nullptr);
1930 auto matmul_2 = node->cast<CNodePtr>();
1931 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
1932 auto softmax = matmul_2->input(1)->cast<CNodePtr>();
1933 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
1934 auto mul = softmax->input(1)->cast<CNodePtr>();
1935 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
1936 auto matmul_1 = mul->input(1)->cast<CNodePtr>();
1937 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
1938 auto transpose = matmul_1->input(2)->cast<CNodePtr>();
1939 MS_CHECK_TRUE_RET(transpose != nullptr, nullptr);
1940
1941 auto q_reshape = matmul_1->input(1)->cast<CNodePtr>();
1942 MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
1943 auto q_trans = q_reshape->input(1)->cast<CNodePtr>();
1944 MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
1945
1946 auto k_reshape = transpose->input(1)->cast<CNodePtr>();
1947 MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
1948 auto k_trans = k_reshape->input(1)->cast<CNodePtr>();
1949 MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
1950
1951 auto v_reshape = matmul_2->input(2)->cast<CNodePtr>();
1952 MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
1953 auto v_trans = v_reshape->input(1)->cast<CNodePtr>();
1954 MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
1955
1956 auto input_tensor_q_shape = GetTensorShape(q_reshape, 1);
1957 auto input_tensor_k_shape = GetTensorShape(k_reshape, 1);
1958 auto input_tensor_v_shape = GetTensorShape(v_reshape, 1);
1959
1960 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
1961 << " , v shape: " << input_tensor_v_shape;
1962 if (input_tensor_q_shape.size() != kNumShapeSize4 || input_tensor_k_shape.size() != kNumShapeSize4 ||
1963 input_tensor_v_shape.size() != kNumShapeSize4) {
1964 MS_LOG(WARNING) << "input shape is not 4 dims";
1965 return nullptr;
1966 }
1967
1968 int64_t next_tokens = kNumMaxNextTokenSize;
1969 float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1970 int64_t num_head = input_tensor_q_shape[kNumIndex1];
1971
1972 if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
1973 fa_parm->seq_threshold)) {
1974 MS_LOG(INFO) << "shape check failed.";
1975 return nullptr;
1976 }
1977
1978 auto fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
1979 next_tokens, scale_value, fa_parm, num_head);
1980 MS_CHECK_TRUE_MSG(fa_node != nullptr, nullptr, "create FA failed, fa_node is nullptr.");
1981   auto manager = Manage(func_graph);
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
1982 (void)manager->Replace(matmul_2, fa_node);
1983 MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
1984 return nullptr;
1985 }
1986
1987 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForVideoComposer(
1988 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
1989 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1990   MS_LOG(INFO) << "flash attention for VideoComposer";
1991 MS_CHECK_TRUE_RET(node != nullptr, nullptr);
1992 auto reshape = node->cast<CNodePtr>();
1993 MS_CHECK_TRUE_RET(reshape != nullptr, nullptr);
1994 auto matmul_2 = reshape->input(1)->cast<CNodePtr>();
1995 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
1996 auto cast_2 = matmul_2->input(1)->cast<CNodePtr>();
1997 MS_CHECK_TRUE_RET(cast_2 != nullptr, nullptr);
1998 auto softmax = cast_2->input(1)->cast<CNodePtr>();
1999 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2000 auto cast_1 = softmax->input(1)->cast<CNodePtr>();
2001 MS_CHECK_TRUE_RET(cast_1 != nullptr, nullptr);
2002 auto mul = cast_1->input(1)->cast<CNodePtr>();
2003 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2004 auto matmul_1 = mul->input(1)->cast<CNodePtr>();
2005 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2006
2007 auto q_reshape = matmul_1->input(1)->cast<CNodePtr>();
2008 MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2009 auto q_trans = q_reshape->input(1)->cast<CNodePtr>();
2010 MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
2011
2012 auto k_trans_2 = matmul_1->input(2)->cast<CNodePtr>();
2013 MS_CHECK_TRUE_RET(k_trans_2 != nullptr, nullptr);
2014 auto k_reshape = k_trans_2->input(1)->cast<CNodePtr>();
2015 MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
2016 auto k_trans = k_reshape->input(1)->cast<CNodePtr>();
2017 MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
2018
2019 auto v_reshape = matmul_2->input(2)->cast<CNodePtr>();
2020 MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
2021 auto v_trans = v_reshape->input(1)->cast<CNodePtr>();
2022 MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
2023
2024 auto input_tensor_q_shape = GetTensorShape(q_reshape, 1);
2025 auto input_tensor_k_shape = GetTensorShape(k_reshape, 1);
2026 auto input_tensor_v_shape = GetTensorShape(v_reshape, 1);
2027
2028 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2029 << " , v shape: " << input_tensor_v_shape;
2030
2031 float scale_value = 0;
2032 int64_t num_head = 0;
2033 int64_t next_tokens = kNumMaxNextTokenSize;
2034 int64_t d_value = 0;
2035 auto mul_const_input = mul->input(kNumIndex2);
2036
2037 if (input_tensor_q_shape.size() != kNumShapeSize4) {
2038 scale_value = GetScaleValueForDynamicShape(mul_const_input);
2039 d_value = 1 / pow(scale_value, kNumPowerTwo);
2040 MS_LOG(INFO) << "d_value: " << d_value;
2041 // process bnsd shape
2042 MS_LOG(INFO) << "get flash attention param for dynamic shape, scale value is " << scale_value;
2043 std::vector<int32_t> new_shape = {0, 0, -1};
2044     auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
    MS_CHECK_TRUE_RET(shape_node != nullptr, nullptr);
2045 auto output_shape_node = node->cast<CNodePtr>();
2046 output_shape_node->set_input(kNumIndex2, shape_node);
2047 auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
2048 num_head = GetNumHeadForSD(q_trans_reshape);
2049 } else if (input_tensor_q_shape.size() == kNumShapeSize4) {
2050 MS_LOG(INFO) << "get flash attention param for static shape.";
2051 // for static shape: get scale value
2052 scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2053 num_head = input_tensor_q_shape[kNumIndex1];
2054 d_value = input_tensor_q_shape[kNumIndex3];
2055 } else {
2056 MS_LOG(WARNING) << "need check Q input tensor shape: " << input_tensor_q_shape;
2057 return nullptr;
2058 }
2059 CNodePtr fa_node = nullptr;
2060 if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2061 fa_parm->seq_threshold)) {
2062 return nullptr;
2063 }
2064 if (d_value == kNumDValue) {
2065 fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value, fa_parm);
2066 } else {
2067 fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
2068 next_tokens, scale_value, fa_parm, num_head);
2069 }
2070 if (fa_node == nullptr) {
2071 return nullptr;
2072 }
2073   auto manager = Manage(func_graph);
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2074 (void)manager->Replace(matmul_2, fa_node);
2075 MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
2076 return nullptr;
2077 }
2078
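// Generic SD path (softmax wrapped in casts). For static 4-D shapes the parameters come from the
// Q shape; for dynamic shapes the scale is read from the Mul constant, the output Reshape is
// rewritten to {0, 0, -1} (presumably to flatten the fused BNSD output back to BSH), and num_head
// comes from the Q reshape. A head dim of kNumDValue (40) goes through the pad + slice path of
// CreateFAForSD15.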
2079 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSD(const std::string &pattern_name,
2080 const FuncGraphPtr &func_graph, const AnfNodePtr &node,
2081 const EquivPtr &equiv,
2082 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2083 auto cnode = node->cast<CNodePtr>();
2084 auto reshape_o2 = cnode;
2085 MS_CHECK_TRUE_RET(reshape_o2 != nullptr, nullptr);
2086 auto output_trans = reshape_o2->input(kNumIndex1)->cast<CNodePtr>();
2087 MS_CHECK_TRUE_RET(output_trans != nullptr, nullptr);
2088 cnode = output_trans->input(kNumIndex1)->cast<CNodePtr>(); // reshape
2089 MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
2090 auto matmul_2 = cnode->input(1)->cast<CNodePtr>();
2091 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2092 auto cast_2 = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2093 MS_CHECK_TRUE_RET(cast_2 != nullptr, nullptr);
2094 auto softmax = cast_2->input(kNumIndex1)->cast<CNodePtr>();
2095 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2096 auto mul = softmax->input(kNumIndex1)->cast<CNodePtr>();
2097 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2098 auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2099 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2100 auto transpose = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
2101 MS_CHECK_TRUE_RET(transpose != nullptr, nullptr);
2102 auto q_reshape = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2103 MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2104 auto k_reshape = transpose->input(kNumIndex1)->cast<CNodePtr>();
2105 MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
2106 auto v_reshape = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
2107 MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
2108
2109 auto q_trans = q_reshape->input(kNumIndex1);
2110 MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
2111 auto k_trans = k_reshape->input(kNumIndex1);
2112 MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
2113 auto v_trans = v_reshape->input(kNumIndex1);
2114 MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
2115
2116 auto input_tensor_q_shape = GetTensorShape(q_reshape, kNumIndex1);
2117 auto input_tensor_k_shape = GetTensorShape(k_reshape, kNumIndex1);
2118 auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
2119 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2120 << " , v shape: " << input_tensor_v_shape;
2121
2122 float scale_value = 0;
2123 int64_t num_head = 0;
2124 int64_t next_tokens = kNumMaxNextTokenSize;
2125 int64_t d_value = 0;
2126 auto mul_const_input = mul->input(kNumIndex2);
2127
2128 if (input_tensor_q_shape.size() != kNumShapeSize4) {
2129 scale_value = GetScaleValueForDynamicShape(mul_const_input);
2130 d_value = 1 / pow(scale_value, kNumPowerTwo);
2131 // process bnsd shape
2132 MS_LOG(INFO) << "get flash attention param for dynamic shape, scale value is " << scale_value;
2133 std::vector<int32_t> new_shape = {0, 0, -1};
2134     auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
    MS_CHECK_TRUE_RET(shape_node != nullptr, nullptr);
2135 auto output_shape_node = node->cast<CNodePtr>();
2136 output_shape_node->set_input(kNumIndex2, shape_node);
2137 auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
2138 num_head = GetNumHeadForSD(q_trans_reshape);
2139 } else if (input_tensor_q_shape.size() == kNumShapeSize4) {
2140 MS_LOG(INFO) << "get flash attention param for static shape.";
2141 // for static shape: get scale value
2142 scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2143 num_head = input_tensor_q_shape[kNumIndex1];
2144 d_value = input_tensor_q_shape[kNumIndex3];
2145 } else {
2146 MS_LOG(WARNING) << "need check Q input tensor shape: " << input_tensor_q_shape;
2147 return nullptr;
2148 }
2149 CNodePtr fa_node = nullptr;
2150 if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2151 fa_parm->seq_threshold)) {
2152 return nullptr;
2153 }
2154 if (d_value == kNumDValue) {
2155 fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value, fa_parm);
2156 } else {
2157 fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
2158 next_tokens, scale_value, fa_parm, num_head);
2159 }
2160 if (fa_node == nullptr) {
2161 return nullptr;
2162 }
2163   auto manager = Manage(func_graph);
  MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2164 (void)manager->Replace(cnode, fa_node);
2165 MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
2166 return nullptr;
2167 }
2168
2169 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDPreMul(
2170 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2171 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2172 auto output_reshape = node->cast<CNodePtr>();
2173 MS_CHECK_TRUE_RET(output_reshape != nullptr, nullptr);
2174 auto output_trans = output_reshape->input(kNumIndex1)->cast<CNodePtr>();
2175 MS_CHECK_TRUE_RET(output_trans != nullptr, nullptr);
2176 auto matmul_2 = output_trans->input(kNumIndex1)->cast<CNodePtr>();
2177 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2178 auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2179 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2180 auto matmul_1 = softmax->input(kNumIndex1)->cast<CNodePtr>();
2181 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2182 auto mul_q = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2183 MS_CHECK_TRUE_RET(mul_q != nullptr, nullptr);
2184 auto q_trans_BNSD = mul_q->input(kNumIndex1)->cast<CNodePtr>();
2185 MS_CHECK_TRUE_RET(q_trans_BNSD != nullptr, nullptr);
2186 auto mul_k = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
2187 MS_CHECK_TRUE_RET(mul_k != nullptr, nullptr);
2188 auto k_trans_BNDS = mul_k->input(kNumIndex1)->cast<CNodePtr>();
2189 MS_CHECK_TRUE_RET(k_trans_BNDS != nullptr, nullptr);
2190 auto v_trans_BNSD = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
2191 MS_CHECK_TRUE_RET(v_trans_BNSD != nullptr, nullptr);
2192 auto input_tensor_q_shape = GetTensorShape(mul_q, kNumIndex1);
2193 auto input_tensor_k_shape = GetTensorShape(mul_k, kNumIndex1);
2194 auto input_tensor_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2195 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2196 << " , v shape: " << input_tensor_v_shape << ". q name: " << q_trans_BNSD->fullname_with_scope()
2197 << ", k name: " << k_trans_BNDS->fullname_with_scope()
2198 << ", v name: " << v_trans_BNSD->fullname_with_scope();
2199
2200 float scale_value = 0;
2201 int64_t num_head = 0;
2202 int64_t next_tokens = kNumMaxNextTokenSize;
2203 int64_t d_value = 0;
2204 if (input_tensor_q_shape.size() == 0 && input_tensor_k_shape.size() == 0 && input_tensor_v_shape.size() == 0) {
2205 MS_CHECK_TRUE_RET(q_trans_BNSD->inputs().size() > kNumIndex1, nullptr);
2206 auto q_reshape = q_trans_BNSD->input(kNumIndex1)->cast<CNodePtr>();
2207 MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2208 num_head = GetReshapeParam(q_reshape, kNumIndex2);
2209 d_value = GetReshapeParam(q_reshape, kNumIndex3);
2210 MS_LOG(INFO) << "num_head: " << num_head << ", d_value: " << d_value;
2211 } else if (input_tensor_q_shape.size() != kNumShapeSize4 || input_tensor_k_shape.size() != kNumShapeSize4) {
2212 auto pd2_q_conv = PD2DecoderPattern(q_trans_BNSD);
2213 if (IpAdapterPattern(q_trans_BNSD, k_trans_BNDS)) {
2214 if (!GetParamForIpAdapterPattern(q_trans_BNSD, k_trans_BNDS, &num_head, &d_value)) {
2215 MS_LOG(INFO) << "Get parameter for IpAdapterPattern failed";
2216 return nullptr;
2217 }
2218 std::vector<int32_t> new_shape = {0, 0, -1};
2219 auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
2220 MS_CHECK_TRUE_RET(shape_node != nullptr, nullptr);
2221 output_reshape->set_input(kNumIndex2, shape_node);
2222 } else if (pd2_q_conv != nullptr) {
2223 MS_LOG(INFO) << "Dynamic shape pattern: PD2DecoderPattern";
2224 auto pd2_q_conv_input2_shape = GetTensorShape(pd2_q_conv, kNumIndex2);
2225 d_value = pd2_q_conv_input2_shape[0];
2226 num_head = GetNumHeadForSD(q_trans_BNSD->input(kNumIndex1)->cast<CNodePtr>());
2227 } else {
2228 MS_LOG(INFO) << "Dynamic shape is not supported, can not fuse FA.";
2229 return nullptr;
2230 }
2231 } else {
2232 MS_LOG(INFO) << "get flash attention param for static shape.";
2233 num_head = input_tensor_q_shape[kNumIndex1];
2234 d_value = input_tensor_q_shape[kNumIndex3];
2235 std::swap(input_tensor_k_shape[kNumIndex2], input_tensor_k_shape[kNumIndex3]);
2236 }
2237 scale_value = 1 / (pow(d_value, kNumPowerHalf));
2238 CNodePtr fa_node = nullptr;
2239 if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2240 fa_parm->seq_threshold)) {
2241 return nullptr;
2242 }
2243 if (d_value == kNumDValue) {
2244 fa_node = CreateFAForSD15(func_graph, node, q_trans_BNSD, k_trans_BNDS, v_trans_BNSD, num_head, next_tokens,
2245 scale_value, fa_parm);
2246 } else {
2247 fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans_BNSD, k_trans_BNDS, v_trans_BNSD,
2248 nullptr, num_head, next_tokens, scale_value, fa_parm, num_head);
2249 }
2250 MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
2251 std::vector<int32_t> new_perm = {kNumIndex0, kNumIndex2, kNumIndex1, kNumIndex3};
2252 auto perm_node = BuildIntVecParameterNode(func_graph, new_perm, k_trans_BNDS->fullname_with_scope() + "_new_perm");
2253 MS_CHECK_TRUE_RET(perm_node != nullptr, nullptr);
2254 k_trans_BNDS->cast<CNodePtr>()->set_input(kNumIndex2, perm_node);
2255 auto manager = Manage(func_graph);
2256 MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2257 (void)manager->Replace(matmul_2, fa_node);
2258 MS_LOG(INFO) << "create prompt flash attention success for pre-mul pattern.";
2259 return nullptr;
2260 }
2261
2262 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(
2263 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2264 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
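// SD attention pattern without the Cast nodes, matched bottom-up:
//   Reshape <- Transpose <- Reshape <- MatMul2(Softmax(Mul(MatMul1(Q_reshape, Transpose(K_reshape)))), V_reshape)
// When fa_parm->format_bsh is set, the fusion consumes the producer MatMuls above the Q/K/V
// reshape/transpose chains (top_matmul_*) and builds a BSH-layout PromptFlashAttention;
// otherwise it keeps the transposed BNSD inputs.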
2265 MS_LOG(INFO) << "format_bsh: " << fa_parm->format_bsh << ", seq_threshold: " << fa_parm->seq_threshold
2266 << ", inner_precise: " << fa_parm->inner_precise;
2267 auto cnode = node->cast<CNodePtr>();
2268 auto reshape_o2 = cnode;
2269 MS_CHECK_TRUE_RET(reshape_o2 != nullptr, nullptr);
2270 auto output_trans = reshape_o2->input(kNumIndex1)->cast<CNodePtr>();
2271 MS_CHECK_TRUE_RET(output_trans != nullptr, nullptr);
2272 cnode = output_trans->input(kNumIndex1)->cast<CNodePtr>(); // reshape
2273 MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
2274 auto matmul_2 = cnode->input(kNumIndex1)->cast<CNodePtr>();
2275 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2276 auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2277 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2278 auto mul = softmax->input(kNumIndex1)->cast<CNodePtr>();
2279 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2280 auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2281 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2282 auto transpose = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
2283 MS_CHECK_TRUE_RET(transpose != nullptr, nullptr);
2284 auto q_reshape = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2285 MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2286 auto k_reshape = transpose->input(kNumIndex1)->cast<CNodePtr>();
2287 MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
2288 auto v_reshape = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
2289 MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
2290
2291 auto q_trans = q_reshape->input(kNumIndex1);
2292 MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
2293 auto k_trans = k_reshape->input(kNumIndex1);
2294 MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
2295 auto v_trans = v_reshape->input(kNumIndex1);
2296 MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
2297
2298 auto input_tensor_q_shape = GetTensorShape(q_reshape, kNumIndex1);
2299 auto input_tensor_k_shape = GetTensorShape(k_reshape, kNumIndex1);
2300 auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
2301 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2302 << " , v shape: " << input_tensor_v_shape;
2303 auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
2304 MS_CHECK_TRUE_RET(q_trans_reshape != nullptr, nullptr);
2305 auto k_trans_reshape = k_trans->cast<CNodePtr>()->input(kNumIndex1);
2306 MS_CHECK_TRUE_RET(k_trans_reshape != nullptr, nullptr);
2307 auto v_trans_reshape = v_trans->cast<CNodePtr>()->input(kNumIndex1);
2308 MS_CHECK_TRUE_RET(v_trans_reshape != nullptr, nullptr);
2309
2310 auto top_matmul_q = q_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
2311 MS_CHECK_TRUE_RET(top_matmul_q != nullptr, nullptr);
2312 auto top_matmul_k = k_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
2313 MS_CHECK_TRUE_RET(top_matmul_k != nullptr, nullptr);
2314 auto top_matmul_v = v_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
2315 MS_CHECK_TRUE_RET(top_matmul_v != nullptr, nullptr);
2316
2317 float scale_value = 0;
2318 int64_t num_head = 0;
2319 int64_t next_tokens = kNumMaxNextTokenSize;
2320 int64_t d_value = 0;
2321 auto mul_const_input = mul->input(kNumIndex2);
2322 bool actual_BSH = false;
2323
2324 if (input_tensor_q_shape.size() != kNumShapeSize4) {
2325 scale_value = GetScaleValueForDynamicShape(mul_const_input);
2326 d_value = 1 / pow(scale_value, kNumPowerTwo);
2327 // process bnsd shape
2328 MS_LOG(INFO) << "get flash attention param for dynamic shape, scale value is " << scale_value;
2329 std::vector<int32_t> new_shape = {0, 0, -1};
2330 auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
2331 auto output_shape_node = node->cast<CNodePtr>();
2332 output_shape_node->set_input(kNumIndex2, shape_node);
2333 num_head = GetNumHeadForSD(q_trans_reshape);
2334 } else {
2335 MS_LOG(INFO) << "get flash attention param for static shape.";
2336 // for static shape: get scale value
2337 scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2338 num_head = input_tensor_q_shape[kNumIndex1];
2339 d_value = input_tensor_q_shape[kNumIndex3];
2340 }
2341 CNodePtr fa_node = nullptr;
2342 if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2343 fa_parm->seq_threshold)) {
2344 MS_LOG(INFO) << "Can not pass shape check.";
2345 return nullptr;
2346 }
2347 if (d_value == kNumDValue) {
2348 fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value, fa_parm);
2349 } else if (fa_parm->format_bsh) {
2350 fa_node = CreatePromptFlashAttentionCnodeForBSH(func_graph, node, top_matmul_q, top_matmul_k, top_matmul_v, nullptr,
2351 num_head, next_tokens, scale_value, fa_parm);
2352 actual_BSH = true;
2353 } else {
2354 fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
2355 next_tokens, scale_value, fa_parm, num_head);
2356 }
2357 MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
2358 auto manager = Manage(func_graph);
2359 MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2360 if (actual_BSH) {
2361 (void)manager->Replace(node, fa_node);
2362 } else {
2363 (void)manager->Replace(cnode, fa_node);
2364 }
2365 MS_LOG(INFO) << "create prompt flash attention success for the without-cast pattern, BSH: " << actual_BSH;
2366 return nullptr;
2367 }
2368
2369 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForPanGu(
2370 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2371 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
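// PanGu attention pattern, matched bottom-up:
//   MatMul2(Cast(Softmax(Add(atten_mask_mul, Cast(MatMul1(Div(Q), K))))), V)
// Q/K/V must have static 4-D (BNSD) shapes. seq_len == 1 corresponds to a single decode step and
// maps to IncreFlashAttention; longer sequences map to PromptFlashAttention with next_tokens = 0.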
2372 auto matmul_2 = node->cast<CNodePtr>();
2373 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2374 auto cast_2 = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2375 MS_CHECK_TRUE_RET(cast_2 != nullptr, nullptr);
2376 auto softmax = cast_2->input(kNumIndex1)->cast<CNodePtr>();
2377 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2378 auto add = softmax->input(kNumIndex1)->cast<CNodePtr>();
2379 MS_CHECK_TRUE_RET(add != nullptr, nullptr);
2380 auto atten_mask_mul = add->input(kNumIndex1)->cast<CNodePtr>();
2381 MS_CHECK_TRUE_RET(atten_mask_mul != nullptr, nullptr);
2382 auto cast_1 = add->input(kNumIndex2)->cast<CNodePtr>();
2383 MS_CHECK_TRUE_RET(cast_1 != nullptr, nullptr);
2384 auto matmul_1 = cast_1->input(kNumIndex1)->cast<CNodePtr>();
2385 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2386 auto div = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2387 MS_CHECK_TRUE_RET(div != nullptr, nullptr);
2388
2389 // PromptFlashAttention input tensor
2390 auto q = div->input(kNumIndex1);
2391 MS_CHECK_TRUE_RET(q != nullptr, nullptr);
2392 auto k = matmul_1->input(kNumIndex2);
2393 MS_CHECK_TRUE_RET(k != nullptr, nullptr);
2394 auto v = matmul_2->input(kNumIndex2);
2395 MS_CHECK_TRUE_RET(v != nullptr, nullptr);
2396 auto atten_mask = atten_mask_mul->input(kNumIndex1)->cast<CNodePtr>();
2397 MS_CHECK_TRUE_RET(atten_mask != nullptr, nullptr);
2398
2399 auto input_tensor_q_shape = GetTensorShape(div, kNumIndex1);
2400 if (input_tensor_q_shape.size() != kNumDimSize4) {
2401 MS_LOG(ERROR) << "q shape is not 4 dims";
2402 return nullptr;
2403 }
2404 auto input_tensor_k_shape = GetTensorShape(matmul_1, kNumIndex2);
2405 if (input_tensor_k_shape.size() != kNumDimSize4) {
2406 MS_LOG(ERROR) << "k shape is not 4 dims";
2407 return nullptr;
2408 }
2409 auto input_tensor_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2410 if (input_tensor_v_shape.size() != kNumDimSize4) {
2411 MS_LOG(ERROR) << "v shape is not 4 dims";
2412 return nullptr;
2413 }
2414 MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << " , k name: " << k->fullname_with_scope()
2415 << " , v name: " << v->fullname_with_scope();
2416 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << ", k shape: " << input_tensor_k_shape
2417 << ", v shape: " << input_tensor_v_shape;
2418
2419 // check input shape
2420 if (input_tensor_q_shape[kNumIndex3] <= 0 || input_tensor_q_shape[kNumIndex1] <= 0) {
2421 MS_LOG(ERROR) << "q N dim or D dim is not a positive static value, can not fuse.";
2422 return nullptr;
2423 }
2424 float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2425 int64_t seq_len = input_tensor_q_shape[kNumIndex2];
2426 if (seq_len != 1) {
2427 return CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
2428 input_tensor_q_shape[kNumIndex1], 0, scale_value, fa_parm,
2429 input_tensor_k_shape[kNumIndex1]);
2430 } else {
2431 MS_LOG(INFO) << "seq len is 1, use incre flash attention.";
2432 return CreateIncreFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
2433 input_tensor_q_shape[kNumIndex1], scale_value,
2434 input_tensor_q_shape[kNumIndex1]);
2435 }
2436 return nullptr;
2437 }
2438
2439 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForLLAMAPatternV1(
2440 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2441 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
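// LLAMA attention pattern V1, matched bottom-up:
//   MatMul2(Cast(Softmax(Cast(Add(attention_mask_mul, Mul(MatMul1(Q, K)))))), V)
// If IsGQAPattern detects a grouped-query layout from the two MatMuls, a GQA-specific node is
// created; otherwise a standard BNSD PromptFlashAttention with the attention mask is built.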
2442 auto matmul_2 = node->cast<CNodePtr>();
2443 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2444 auto cast_2 = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2445 MS_CHECK_TRUE_RET(cast_2 != nullptr, nullptr);
2446 auto softmax = cast_2->input(kNumIndex1)->cast<CNodePtr>();
2447 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2448 auto cast_1 = softmax->input(kNumIndex1)->cast<CNodePtr>();
2449 MS_CHECK_TRUE_RET(cast_1 != nullptr, nullptr);
2450 auto add = cast_1->input(kNumIndex1)->cast<CNodePtr>();
2451 MS_CHECK_TRUE_RET(add != nullptr, nullptr);
2452
2453 auto attention_mask_mul = add->input(kNumIndex1)->cast<CNodePtr>();
2454 MS_CHECK_TRUE_RET(attention_mask_mul != nullptr, nullptr);
2455
2456 auto mul = add->input(kNumIndex2)->cast<CNodePtr>();
2457 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2458 auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2459 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2460
2461 auto pfa_q_shape = GetTensorShape(matmul_1, kNumIndex1);
2462 auto pfa_k_shape = GetTensorShape(matmul_1, kNumIndex2);
2463 auto pfa_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2464 MS_LOG(INFO) << "q shape: " << pfa_q_shape << ", k shape: " << pfa_k_shape << ", v shape: " << pfa_v_shape;
2465
2466 // process for GQA
2467 if (IsGQAPattern(matmul_1, matmul_2)) {
2468 MS_LOG(INFO) << "create GQA node for bnsd.";
2469 return CreateGQACNodeForBNSD(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2470 }
2471 return CreateFAForBNSDWithAttenMask(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2472 }
2473
2474 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForLLAMAPatternV2(
2475 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2476 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
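// LLAMA attention pattern V2: identical to V1 except that the two Cast nodes around the Softmax
// are absent, so MatMul2 consumes the Softmax output directly.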
2477 auto matmul_2 = node->cast<CNodePtr>();
2478 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2479 auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2480 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2481 auto add = softmax->input(kNumIndex1)->cast<CNodePtr>();
2482 MS_CHECK_TRUE_RET(add != nullptr, nullptr);
2483
2484 auto attention_mask_mul = add->input(kNumIndex1)->cast<CNodePtr>();
2485 MS_CHECK_TRUE_RET(attention_mask_mul != nullptr, nullptr);
2486
2487 auto mul = add->input(kNumIndex2)->cast<CNodePtr>();
2488 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2489 auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2490 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2491
2492 auto pfa_q_shape = GetTensorShape(matmul_1, kNumIndex1);
2493 auto pfa_k_shape = GetTensorShape(matmul_1, kNumIndex2);
2494 auto pfa_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2495 MS_LOG(INFO) << "q shape: " << pfa_q_shape << ", k shape: " << pfa_k_shape << ", v shape: " << pfa_v_shape;
2496
2497 // process for GQA
2498 if (IsGQAPattern(matmul_1, matmul_2)) {
2499 MS_LOG(INFO) << "create GQA node for bnsd.";
2500 return CreateGQACNodeForBNSD(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2501 }
2502 return CreateFAForBNSDWithAttenMask(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2503 }
2504
2505 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForBaiChuanPattern(
2506 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2507 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
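// BaiChuan attention pattern: like the LLAMA patterns, but an extra Add sits between the scaling
// Mul and the mask Add (likely an alibi-style position bias), so the walk goes
// add -> add_up -> mul -> matmul_1.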
2508 auto matmul_2 = node->cast<CNodePtr>();
2509 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2510 auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2511 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2512 auto add = softmax->input(kNumIndex1)->cast<CNodePtr>();
2513 MS_CHECK_TRUE_RET(add != nullptr, nullptr);
2514
2515 auto attention_mask_mul = add->input(kNumIndex1)->cast<CNodePtr>();
2516 MS_CHECK_TRUE_RET(attention_mask_mul != nullptr, nullptr);
2517
2518 auto add_up = add->input(kNumIndex2)->cast<CNodePtr>();
MS_CHECK_TRUE_RET(add_up != nullptr, nullptr);
2519 auto mul = add_up->input(kNumIndex1)->cast<CNodePtr>();
2520 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2521 auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2522 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2523 // process for GQA
2524 auto pfa_q_shape = GetTensorShape(matmul_1, kNumIndex1);
2525 auto pfa_k_shape = GetTensorShape(matmul_1, kNumIndex2);
2526 auto pfa_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2527 MS_LOG(INFO) << "q shape: " << pfa_q_shape << ", k shape: " << pfa_k_shape << ", v shape: " << pfa_v_shape;
2528 if (IsGQAPattern(matmul_1, matmul_2)) {
2529 MS_LOG(INFO) << "create GQA node for BNSD.";
2530 return CreateGQACNodeForBNSD(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2531 }
2532 return CreateFAForBNSDWithAttenMask(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2533 }
2534
2535 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDEinsum(
2536 const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2537 const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
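// Einsum-lowered SD attention pattern: same walk as the without-cast pattern, except that
// MatMul1 takes the K reshape directly (no separate Transpose node on the K branch).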
2538 auto cnode = node->cast<CNodePtr>();
2539 auto reshape_o2 = cnode;
2540 MS_CHECK_TRUE_RET(reshape_o2 != nullptr, nullptr);
2541 auto output_trans = reshape_o2->input(kNumIndex1)->cast<CNodePtr>();
2542 MS_CHECK_TRUE_RET(output_trans != nullptr, nullptr);
2543 cnode = output_trans->input(kNumIndex1)->cast<CNodePtr>(); // reshape
2544 MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
2545 auto matmul_2 = cnode->input(kNumIndex1)->cast<CNodePtr>();
2546 MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2547 auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2548 MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2549 auto mul = softmax->input(kNumIndex1)->cast<CNodePtr>();
2550 MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2551 auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2552 MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2553 auto q_reshape = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2554 MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2555 auto k_reshape = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
2556 MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
2557 auto v_reshape = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
2558 MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
2559
2560 auto q_trans = q_reshape->input(kNumIndex1);
2561 MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
2562 auto k_trans = k_reshape->input(kNumIndex1);
2563 MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
2564 auto v_trans = v_reshape->input(kNumIndex1);
2565 MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
2566
2567 auto input_tensor_q_shape = GetTensorShape(q_reshape, kNumIndex1);
2568 auto input_tensor_k_shape = GetTensorShape(k_reshape, kNumIndex1);
2569 auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
2570 MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2571 << " , v shape: " << input_tensor_v_shape;
2572
2573 float scale_value = 0;
2574 int64_t num_head = 0;
2575 int64_t next_tokens = kNumMaxNextTokenSize;
2576 int64_t d_value = 0;
2577 auto mul_const_input = mul->input(kNumIndex2);
2578
2579 if (input_tensor_q_shape.size() != kNumShapeSize4) {
2580 scale_value = GetScaleValueForDynamicShape(mul_const_input);
2581 d_value = 1 / pow(scale_value, kNumPowerTwo);
2582 // process bnsd shape
2583 MS_LOG(INFO) << "get flash attention param for dynamic shape, scale value is " << scale_value;
2584 std::vector<int32_t> new_shape = {0, 0, -1};
2585 auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
2586 auto output_shape_node = node->cast<CNodePtr>();
2587 output_shape_node->set_input(kNumIndex2, shape_node);
2588 auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
2589 num_head = GetNumHeadForSD(q_trans_reshape);
2590 } else if (input_tensor_q_shape.size() == kNumShapeSize4) {
2591 MS_LOG(INFO) << "get flash attention param for static shape.";
2592 // for static shape: get scale value
2593 scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2594 num_head = input_tensor_q_shape[kNumIndex1];
2595 d_value = input_tensor_q_shape[kNumIndex3];
2596 } else {
2597 MS_LOG(WARNING) << "need check Q input tensor shape: " << input_tensor_q_shape;
2598 return nullptr;
2599 }
2600 CNodePtr fa_node = nullptr;
2601 if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2602 fa_parm->seq_threshold)) {
2603 return nullptr;
2604 }
2605
2606 if (d_value == kNumDValue) {
2607 fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value, fa_parm);
2608 } else {
2609 fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
2610 next_tokens, scale_value, fa_parm, num_head);
2611 }
2612 if (fa_node == nullptr) {
2613 return nullptr;
2614 }
2615 auto manager = Manage(func_graph);
MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2616 (void)manager->Replace(cnode, fa_node);
2617 MS_LOG(INFO) << "create prompt flash attention success for the stable diffusion Einsum pattern.";
2618 return nullptr;
2619 }
2620
2621 std::shared_ptr<FlashAttentionParm> FlashAttentionFusion::ParseFAParam() const {
2622 FlashAttentionParm fa_param;
2623 // op_attrs=FlashAttention:input_layout:BSH;
2624 // FlashAttention:seq_threshold:1024;
2625 // FlashAttention:inner_precise:1;
2626 // FlashAttention:sparse_mode:0
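// Each numeric attribute is validated by round-tripping through atoi/to_string: e.g.
// "1024" -> 1024 -> "1024" is accepted, while "1024x" or "10 24" is rejected because the
// re-serialized value no longer matches the original string.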
2627 if (op_attrs_map_.find("FlashAttention") != op_attrs_map_.end()) {
2628 auto attr_map = op_attrs_map_.at("FlashAttention");
2629 for (const auto &attr : attr_map) {
2630 auto attr_value = attr.second;
2631 if (attr.first == "input_layout") {
2632 if (strcmp(attr_value.c_str(), "BSH") == 0) {
2633 fa_param.format_bsh = true;
2634 MS_LOG(INFO) << "Use user config, FA input_layout is: " << fa_param.format_bsh;
2635 } else if (strcmp(attr_value.c_str(), "BNSD") == 0) {
2636 fa_param.format_bsh = false;
2637 MS_LOG(INFO) << "Use user config, FA input_layout is: " << fa_param.format_bsh;
2638 } else {
2639 MS_LOG(WARNING) << "FA input_layout only supports BSH and BNSD, but got " << attr_value;
2640 return nullptr;
2641 }
2642 } else if (attr.first == "seq_threshold") {
2643 int seq_threshold = std::atoi(attr_value.c_str());
2644 if (std::to_string(seq_threshold) == attr_value && seq_threshold >= 0) {
2645 fa_param.seq_threshold = mindspore::IntToLong(seq_threshold);
2646 MS_LOG(INFO) << "Use user config, FA seq_threshold is: " << fa_param.seq_threshold;
2647 } else {
2648 MS_LOG(WARNING) << "FA seq_threshold only supports a non-negative integer, but got " << attr_value;
2649 return nullptr;
2650 }
2651 } else if (attr.first == "inner_precise") {
2652 if (FlashAttentionFusion::GetSocVersion() == kSocVersionAscend310P) {
2653 MS_LOG(WARNING) << "FA inner_precise is not supported on Ascend310P.";
2654 return nullptr;
2655 }
2656 int inner_precise = std::atoi(attr_value.c_str());
2657 if (std::to_string(inner_precise) == attr_value && (inner_precise == 0 || inner_precise == 1)) {
2658 MS_LOG(INFO) << "Use user config, FA inner_precise is: " << attr_value;
2659 fa_param.inner_precise = inner_precise;
2660 } else {
2661 MS_LOG(WARNING) << "FA inner_precise only supports 0 or 1, but got " << attr_value;
2662 return nullptr;
2663 }
2664 } else if (attr.first == "sparse_mode") {
2665 if (FlashAttentionFusion::GetSocVersion() != kSocVersionAscend310P) {
2666 MS_LOG(WARNING) << "FA sparse_mode is only supported on Ascend310P, but the current SoC version is "
2667 << FlashAttentionFusion::GetSocVersion();
2668 return nullptr;
2669 }
2670 int sparse_mode = std::atoi(attr_value.c_str());
2671 if (std::to_string(sparse_mode) == attr_value && (sparse_mode == 0 || sparse_mode == 10)) {
2672 MS_LOG(INFO) << "Use user config, FA sparse_mode is: " << attr_value;
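// Note: on Ascend310P the sparse_mode value is carried in the inner_precise field of
// FlashAttentionParm; the SoC checks above keep the two options mutually exclusive.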
2673 fa_param.inner_precise = sparse_mode;
2674 } else {
2675 MS_LOG(WARNING) << "FA sparse_mode only supports 0 or 10, but got " << attr_value;
2676 return nullptr;
2677 }
2678 } else {
2679 MS_LOG(WARNING) << "FA attr only supports input_layout, seq_threshold, inner_precise and sparse_mode, but got "
2680 << attr.first;
2681 return nullptr;
2682 }
2683 }
2684 }
2685 auto fa_param_ptr = std::make_shared<FlashAttentionParm>(fa_param);
2686 return fa_param_ptr;
2687 }
2688
2689 AnfNodePtr FlashAttentionFusion::Process(const std::string &patten_name, const FuncGraphPtr &func_graph,
2690 const AnfNodePtr &node, const EquivPtr &equiv) const {
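// Entry point called for every pattern match: validate the anchor node, parse the user FA
// attributes, dispatch to the builder that corresponds to the matched pattern name, and replace
// the anchor node with the fused node. A nullptr return leaves this match unfused.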
2691 MS_LOG(INFO) << "do flash attention fusion, pattern name: " << patten_name;
2692 MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
2693 MS_CHECK_TRUE_RET(node != nullptr, nullptr);
2694 MS_CHECK_TRUE_RET(equiv != nullptr, nullptr);
2695 if (!utils::isa<CNodePtr>(node)) {
2696 MS_LOG(ERROR) << "this node is not a CNode, node name: " << node->fullname_with_scope();
2697 return nullptr;
2698 }
2699 if (IsMarkedTrainOp(utils::cast<CNodePtr>(node))) {
2700 MS_LOG(ERROR) << "node is a train op, can not be fused.";
2701 return nullptr;
2702 }
2703 auto manager = Manage(func_graph);
2704 MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2705 CNodePtr flash_attention_node = nullptr;
2706 auto fa_param_ptr = ParseFAParam();
2707 MS_CHECK_TRUE_RET(fa_param_ptr != nullptr, nullptr);
2708 if (patten_name == kNameFlashAttentionPatternForSDBSH) {
2709 MS_LOG(INFO) << "start create flash attention node for stable diffusion.";
2710 flash_attention_node = CreateFlashAttentionNodeForSD(patten_name, func_graph, node, equiv, fa_param_ptr);
2711 } else if (patten_name == kNameFlashAttentionPatternForPanGu) {
2712 MS_LOG(INFO) << "start create flash attention node for PanGu models.";
2713 flash_attention_node = CreateFlashAttentionNodeForPanGu(patten_name, func_graph, node, equiv, fa_param_ptr);
2714 } else if (patten_name == kNameFlashAttentionPatternForLLAMAPatternV1) {
2715 MS_LOG(INFO) << "start create flash attention node for LLAMA Pattern V1.";
2716 flash_attention_node =
2717 CreateFlashAttentionNodeForLLAMAPatternV1(patten_name, func_graph, node, equiv, fa_param_ptr);
2718 } else if (patten_name == kNameFlashAttentionPatternForMsSDPseShift) {
2719 MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion pse-shift.";
2720 flash_attention_node = CreateFlashAttentionNodeForMsSDPseShift(patten_name, func_graph, node, equiv, fa_param_ptr);
2721 } else if (patten_name == kNameFlashAttentionPatternForLLAMAPatternV2) {
2722 MS_LOG(INFO) << "start create flash attention node for LLAMA Pattern V2.";
2723 flash_attention_node =
2724 CreateFlashAttentionNodeForLLAMAPatternV2(patten_name, func_graph, node, equiv, fa_param_ptr);
2725 } else if (patten_name == kNameFlashAttentionPatternForBaiChuan) {
2726 MS_LOG(INFO) << "start create flash attention node for BaiChuan models.";
2727 flash_attention_node =
2728 CreateFlashAttentionNodeForBaiChuanPattern(patten_name, func_graph, node, equiv, fa_param_ptr);
2729 } else if (patten_name == kNameFlashAttentionPatternForVideoComposer) {
2730 MS_LOG(INFO) << "start create flash attention node for Video Composer models.";
2731 flash_attention_node = CreateFlashAttentionNodeForVideoComposer(patten_name, func_graph, node, equiv, fa_param_ptr);
2732 } else if (patten_name == kNameFlashAttentionPatternForMsSDXL) {
2733 MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion XL version.";
2734 flash_attention_node = CreateFlashAttentionNodeForMsSDXL(patten_name, func_graph, node, equiv, fa_param_ptr);
2735 } else if (patten_name == kNameFlashAttentionPatternForMsSD21) {
2736 MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion 2.1 version.";
2737 flash_attention_node = CreateFlashAttentionNodeForMsSD21(patten_name, func_graph, node, equiv, fa_param_ptr);
2738 } else if (patten_name == kNameFlashAttentionPatternForSDPreMul) {
2739 MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion PreMul.";
2740 flash_attention_node = CreateFlashAttentionNodeForSDPreMul(patten_name, func_graph, node, equiv, fa_param_ptr);
2741 } else if (patten_name == kNameFlashAttentionPatternForSDWithoutCast) {
2742 MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion without cast.";
2743 flash_attention_node = CreateFlashAttentionNodeForSDWithoutCast(patten_name, func_graph, node, equiv, fa_param_ptr);
2744 } else if (patten_name == kNameFlashAttentionPatternForSDEinsum) {
2745 MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion with Einsum.";
2746 flash_attention_node = CreateFlashAttentionNodeForSDEinsum(patten_name, func_graph, node, equiv, fa_param_ptr);
2747 } else {
2748 MS_LOG(ERROR) << "unsupported pattern name: " << patten_name;
2749 }
2750 if (flash_attention_node == nullptr) {
2751 MS_LOG(INFO) << "flash attention op not fused.";
2752 return nullptr;
2753 }
2754 manager->Replace(node, flash_attention_node);
2755 MS_LOG(INFO) << "flash attention node fusion success, fusion node name: "
2756 << flash_attention_node->fullname_with_scope();
2757 return flash_attention_node;
2758 }
2759 } // namespace mindspore::opt
2760