1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #define USE_DEPRECATED_API
17 #include "tools/optimizer/fusion/flash_attention_fusion.h"
18 #include <memory>
19 #include <utility>
20 #include <string>
21 #include "ops/op_utils.h"
22 #include "ops/array_ops.h"
23 #include "ops/nn_ops.h"
24 #include "tools/optimizer/common/gllo_utils.h"
25 #include "mindspore/core/ops/lite_ops.h"
26 #include "ops/incre_flash_attention.h"
27 #include "ops/prompt_flash_attention.h"
28 #include "ops/fusion/pad_fusion.h"
29 #include "ops/slice.h"
30 #include "ops/auto_generate/gen_lite_ops.h"
31 
32 namespace mindspore::opt {
33 namespace {
34 static int kNameIndex = 0;
35 constexpr auto kNameFlashAttentionPatternForMsSD21 = "FlashAttentionPatternForMsSD21";
36 constexpr auto kNameFlashAttentionPatternForMsSDXL = "FlashAttentionPatternForMsSDXL";
37 constexpr auto kNameFlashAttentionPatternForVideoComposer = "FlashAttentionPatternForVideoComposer";
38 constexpr auto kNameFlashAttentionPatternForSDBSH = "FlashAttentionPatternForSDBSH";
39 constexpr auto kNameFlashAttentionPatternForSDWithoutCast = "FlashAttentionPatternForSDWithoutCast";
40 constexpr auto kNameFlashAttentionPatternForSDPreMul = "FlashAttentionPatternForSDPreMul";
41 constexpr auto kNameFlashAttentionPatternForPanGu = "FlashAttentionPatternForPanGu";
42 constexpr auto kNameFlashAttentionPatternForLLAMAPatternV1 = "FlashAttentionPatternForLLAMAPatternV1";
43 constexpr auto kNameFlashAttentionPatternForLLAMAPatternV2 = "FlashAttentionPatternForLLAMAPatternV2";
44 constexpr auto kNameFlashAttentionPatternForBaiChuan = "FlashAttentionPatternForBaiChuan";
45 constexpr auto kNameFlashAttentionPatternForMsSDPseShift = "FlashAttentionPatternForMsSDPseShift";
46 constexpr auto kNameFlashAttentionPatternForSDEinsum = "FlashAttentionPatternForSDEinsum";
47 constexpr auto kNamePadNodeSuffix = "_fa_pad";
48 constexpr auto kSocVersionAscend310P = "Ascend310P3";
49 constexpr size_t high_inner_precise = 0;
50 constexpr size_t high_performance = 1;
51 constexpr size_t kNumIndex0 = 0;
52 constexpr size_t kNumIndex1 = 1;
53 constexpr size_t kNumIndex2 = 2;
54 constexpr size_t kNumIndex3 = 3;
55 constexpr size_t kNumDimSize4 = 4;
56 constexpr size_t kNumShapeSize2 = 2;
57 constexpr size_t kNumShapeSize3 = 3;
58 constexpr size_t kNumShapeSize4 = 4;
59 constexpr int64_t kNumMaxBatchLenSize = 128;
60 constexpr int64_t kNumMaxNextTokenSize = 65535;
61 constexpr int kNumMultiple32 = 32;
62 constexpr int kNumMultiple16 = 16;
63 constexpr int64_t kNumDValue = 40;
64 constexpr int64_t kNumPadSize = 8;
65 constexpr int kNumPowerTwo = 2;
66 constexpr float kNumPowerHalf = 0.5;
67 
68 bool IsDivNode(const BaseRef &n) {
69   if (utils::isa<AnfNodePtr>(n)) {
70     auto anf_node = utils::cast<AnfNodePtr>(n);
71     return CheckPrimitiveType(anf_node, prim::kPrimDiv) || CheckPrimitiveType(anf_node, prim::kPrimRealDiv);
72   }
73   return false;
74 }
75 
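// Returns true when the K input of the Q*K matmul and the V input of the score*V matmul both come
// from a Reshape(Tile(...)) chain, i.e. the KV heads are tiled/broadcast as in grouped-query attention.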
76 bool IsGQAPattern(const CNodePtr qk_matmul, const CNodePtr v_matmul) {
77   auto k_reshape = qk_matmul->input(kNumIndex2)->cast<CNodePtr>();
78   if (!CheckPrimitiveType(k_reshape, prim::kPrimReshape)) {
79     return false;
80   }
81   auto k_tile = k_reshape->input(kNumIndex1)->cast<CNodePtr>();
82   if (!CheckPrimitiveType(k_tile, prim::kPrimTile)) {
83     return false;
84   }
85   auto v_reshape = v_matmul->input(kNumIndex2)->cast<CNodePtr>();
86   if (!CheckPrimitiveType(v_reshape, prim::kPrimReshape)) {
87     return false;
88   }
89   auto v_tile = v_reshape->input(kNumIndex1)->cast<CNodePtr>();
90   if (!CheckPrimitiveType(v_tile, prim::kPrimTile)) {
91     return false;
92   }
93   return true;
94 }
95 
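// Validates the scale and Q/K/V shapes for PromptFlashAttention fusion. For static 4-D shapes the scale
// is recomputed as 1/sqrt(D) and the batch size and sequence lengths are range-checked; otherwise D is
// derived from the scale. D must be a multiple of 16 or equal to 40.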
96 bool PFACheckShape(float scale_value, const std::vector<int64_t> &q_shape, const std::vector<int64_t> &k_shape,
97                    const std::vector<int64_t> &v_shape, int64_t seq_threshold = 0) {
98   if (scale_value < 0) {
99     MS_LOG(WARNING) << "scale value is invalid.";
100     return false;
101   }
102   int64_t d_value = 0;
103   if (q_shape.size() == kNumShapeSize4 && k_shape.size() == kNumShapeSize4 && v_shape.size() == kNumShapeSize4) {
104     MS_LOG(INFO) << "get flash attention param for static shape.";
105     if (q_shape[kNumIndex0] >= kNumMaxBatchLenSize) {
106       MS_LOG(INFO) << "FA not supported: batch size is too large.";
107       return false;
108     }
109     // for static shape: get scale value
110     scale_value = 1 / (pow(q_shape[kNumIndex3], kNumPowerHalf));
111     auto q_seq_len = q_shape[kNumIndex2];
112     auto k_seq_len = k_shape[kNumIndex2];
113     auto v_seq_len = v_shape[kNumIndex2];
114     d_value = q_shape[kNumIndex3];
115     MS_LOG(INFO) << "check param in stable diffusion models, scale_value: " << scale_value
116                  << ", q_seq_len: " << q_seq_len << ", k_seq_len: " << k_seq_len << ", v_seq_len: " << v_seq_len
117                  << ", d_value: " << d_value;
118     if (q_seq_len < seq_threshold || k_seq_len < seq_threshold || v_seq_len < seq_threshold) {
119       MS_LOG(INFO) << "seq len is less than seq_threshold, seq_threshold is: " << seq_threshold;
120       return false;
121     }
122   } else {
123     d_value = std::lround(1 / pow(scale_value, kNumPowerTwo));
124   }
125   if (static_cast<int>(d_value) % kNumMultiple16 == 0 || static_cast<int>(d_value) == kNumDValue) {
126     return true;
127   }
128   MS_LOG(INFO) << "D value must be an integer multiple of 16 or equal to 40, d value: " << static_cast<int>(d_value);
129   return false;
130 }
131 
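// Reads element `index` of the Reshape node's constant shape input (expects a 4-element int32 tensor);
// returns -1 on any mismatch.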
132 int32_t GetReshapeParam(const AnfNodePtr &reshape_node, size_t index) {
133   if (!utils::isa<CNodePtr>(reshape_node)) {
134     MS_LOG(INFO) << "reshape_node is not CNode!";
135     return -1;
136   }
137   auto reshape_cnode = reshape_node->cast<CNodePtr>();
138   if (reshape_cnode->inputs().size() < kNumShapeSize3) {
139     MS_LOG(WARNING) << "reshape_cnode size < 3!";
140     return -1;
141   }
142   auto reshape_input_2 = reshape_cnode->input(kNumIndex2);
143   if (!utils::isa<ParameterPtr>(reshape_input_2)) {
144     MS_LOG(INFO) << "reshape_input_2 is not ParameterPtr!";
145     return -1;
146   }
147   auto reshape_param = reshape_input_2->cast<ParameterPtr>();
148   if (reshape_param == nullptr) {
149     MS_LOG(WARNING) << "reshape_param is nullptr!";
150     return -1;
151   }
152   auto reshape_default_param = reshape_param->default_param();
153   if (reshape_default_param == nullptr) {
154     MS_LOG(WARNING) << "reshape_default_param is nullptr!";
155     return -1;
156   }
157   auto reshape_value = std::dynamic_pointer_cast<tensor::Tensor>(reshape_default_param);
158   if (reshape_value == nullptr) {
159     MS_LOG(WARNING) << "reshape_value is nullptr!";
160     return -1;
161   }
162   if (reshape_value->ElementsNum() != kNumShapeSize4) {
163     MS_LOG(WARNING) << "reshape_value elements num is not 4, ElementsNum is: " << reshape_value->ElementsNum();
164     return -1;
165   }
166   if (reshape_value->data_type() != kNumberTypeInt32) {
167     MS_LOG(WARNING) << "reshape_value data type is not int32, other data types are not supported.";
168     return -1;
169   }
170   auto reshape_data = static_cast<int32_t *>(reshape_value->data_c());
171   if (reshape_data == nullptr) {
172     MS_LOG(WARNING) << "reshape_data is nullptr.";
173     return -1;
174   }
175   return static_cast<int32_t>(reshape_data[index]);
176 }
177 
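// Extracts the head count from the constant input of the Concat node that builds the Reshape target
// shape in SD graphs; supports int32/int64 scalars and returns -1 on failure.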
178 int64_t GetNumHeadForSD(const AnfNodePtr &q_trans_reshape) {
179   auto concat_cnode = q_trans_reshape->cast<CNodePtr>()->input(kNumIndex2)->cast<CNodePtr>();
180   if (concat_cnode == nullptr) {
181     MS_LOG(WARNING) << "concat_cnode is nullptr.";
182     return -1;
183   }
184   auto concat_const_input = concat_cnode->input(kNumIndex3);
185   if (!utils::isa<ParameterPtr>(concat_const_input)) {
186     MS_LOG(WARNING) << "concat_const_input is not ParameterPtr.";
187     return -1;
188   }
189   auto concat_param = concat_cnode->input(kNumIndex3)->cast<ParameterPtr>()->default_param();
190   if (concat_param == nullptr) {
191     MS_LOG(WARNING) << "concat_param is nullptr.";
192     return -1;
193   }
194   auto concat_value = std::dynamic_pointer_cast<tensor::Tensor>(concat_param);
195   if (concat_value == nullptr) {
196     MS_LOG(WARNING) << "concat_value is nullptr.";
197     return -1;
198   }
199   if (concat_value->ElementsNum() != 1) {
200     MS_LOG(WARNING) << "concat value elements num is not 1, ElementsNum is: " << concat_value->ElementsNum();
201     return -1;
202   }
203   if (concat_value->data_type() == kNumberTypeInt32) {
204     auto concat_data = static_cast<int32_t *>(concat_value->data_c());
205     if (concat_data == nullptr) {
206       MS_LOG(WARNING) << "concat_data is nullptr.";
207       return -1;
208     }
209     return static_cast<int64_t>(concat_data[0]);
210   } else if (concat_value->data_type() == kNumberTypeInt64) {
211     auto concat_data = static_cast<int64_t *>(concat_value->data_c());
212     if (concat_data == nullptr) {
213       MS_LOG(WARNING) << "concat_data is nullptr.";
214       return -1;
215     }
216     return static_cast<int64_t>(concat_data[0]);
217   } else {
218     MS_LOG(WARNING) << "head num data type is not int32 or int64, other data types are not supported.";
219     return -1;
220   }
221 }
222 
223 // reshape(matmul, concat)->trans(reshape)->FA
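// Checks that `input` matches the chain above: a Transpose whose Reshape input takes (MatMulFusion, Concat).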
224 bool CheckIpAdapterInput(const CNodePtr &input) {
225   auto trans = input;
226   MS_CHECK_TRUE_RET(trans != nullptr, false);
227   if (!CheckPrimitiveType(trans, prim::kPrimTranspose)) {
228     MS_LOG(INFO) << "node is not the expected Transpose op: " << trans->fullname_with_scope();
229     return false;
230   }
231 
232   if (trans->inputs().size() != kNumShapeSize3) {
233     return false;
234   }
235   auto reshape = trans->input(kNumIndex1)->cast<CNodePtr>();
236   MS_CHECK_TRUE_RET(reshape != nullptr, false);
237   if (!CheckPrimitiveType(reshape, prim::kPrimReshape)) {
238     MS_LOG(INFO) << "node is not the expected Reshape op: " << reshape->fullname_with_scope();
239     return false;
240   }
241 
242   if (reshape->inputs().size() != kNumShapeSize3) {
243     return false;
244   }
245   auto matmul = reshape->input(kNumIndex1)->cast<CNodePtr>();
246   MS_CHECK_TRUE_RET(matmul != nullptr, false);
247   if (!CheckPrimitiveType(matmul, prim::kPrimMatMulFusion)) {
248     MS_LOG(INFO) << "node is not the expected MatMulFusion op: " << matmul->fullname_with_scope();
249     return false;
250   }
251   auto concat = reshape->input(kNumIndex2)->cast<CNodePtr>();
252   MS_CHECK_TRUE_RET(concat != nullptr, false);
253   if (!CheckPrimitiveType(concat, prim::kPrimConcat)) {
254     MS_LOG(INFO) << "node is not the expected Concat op: " << concat->fullname_with_scope();
255     return false;
256   }
257   return true;
258 }
259 
260 bool IpAdapterPattern(const CNodePtr q_input, const CNodePtr k_input) {
261   return CheckIpAdapterInput(q_input) && CheckIpAdapterInput(k_input);
262 }
263 
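// Walks the Q branch upward (Transpose <- Reshape <- Transpose <- Reshape <- Conv2DFusion) and returns
// the Conv2D node; returns nullptr if the chain does not match.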
264 const CNodePtr PD2DecoderPattern(const CNodePtr &q_trans_BNSD) {
265   auto q_reshape_BSND = q_trans_BNSD->input(kNumIndex1)->cast<CNodePtr>();
266   MS_CHECK_TRUE_RET(q_reshape_BSND != nullptr && q_reshape_BSND->inputs().size() == kNumShapeSize3, nullptr);
267   if (!CheckPrimitiveType(q_reshape_BSND, prim::kPrimReshape)) {
268     MS_LOG(INFO) << "node is not the expected Reshape op: " << q_reshape_BSND->fullname_with_scope();
269     return nullptr;
270   }
271 
272   auto q_trans_BSH = q_reshape_BSND->input(kNumIndex1)->cast<CNodePtr>();
273   MS_CHECK_TRUE_RET(q_trans_BSH != nullptr && q_trans_BSH->inputs().size() == kNumShapeSize3, nullptr);
274   if (!CheckPrimitiveType(q_trans_BSH, prim::kPrimTranspose)) {
275     MS_LOG(INFO) << "node is not the expected Transpose op: " << q_trans_BSH->fullname_with_scope();
276     return nullptr;
277   }
278 
279   auto q_reshape_BSH = q_trans_BSH->input(kNumIndex1)->cast<CNodePtr>();
280   MS_CHECK_TRUE_RET(q_reshape_BSH != nullptr && q_reshape_BSH->inputs().size() == kNumShapeSize3, nullptr);
281   if (!CheckPrimitiveType(q_reshape_BSH, prim::kPrimReshape)) {
282     MS_LOG(INFO) << "node is not the expected Reshape op: " << q_reshape_BSH->fullname_with_scope();
283     return nullptr;
284   }
285 
286   auto q_conv = q_reshape_BSH->input(kNumIndex1)->cast<CNodePtr>();
287   if (!CheckPrimitiveType(q_conv, prim::kPrimConv2DFusion)) {
288     MS_LOG(INFO) << "node is not the expected Conv2DFusion op: " << q_conv->fullname_with_scope();
289     return nullptr;
290   }
291   return q_conv;
292 }
293 
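// Fetches the static shape of `cnode`'s input at `input_index` from its abstract; returns {} on failure.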
294 std::vector<int64_t> GetTensorShape(CNodePtr cnode, size_t input_index) {
295   auto abstract = GetCNodeInputAbstract(cnode, input_index);
296   if (abstract == nullptr) {
297     MS_LOG(ERROR) << "GetCNodeInputAbstract failed in prompt flash attention fusion.";
298     return {};
299   }
300   std::vector<int64_t> shape = {};
301   if (FetchShapeFromAbstract(abstract, &shape) != lite::RET_OK) {
302     MS_LOG(ERROR) << "FetchShapeFromAbstract failed.";
303     return {};
304   }
305   return shape;
306 }
307 
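// Derives num_head (via GetNumHeadForSD) and d_value = hidden_size / num_head, where hidden_size is taken
// from the second dimension of the Q projection matmul's 2-D weight.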
308 bool GetParamForIpAdapterPattern(const CNodePtr &q_trans_BNSD, const CNodePtr &k_trans_BNDS, int64_t *num_head,
309                                  int64_t *d_value) {
310   if (num_head == nullptr) {
311     MS_LOG(ERROR) << "GetParamForIpAdapterPattern failed, num_head is nullptr!";
312     return false;
313   }
314   if (d_value == nullptr) {
315     MS_LOG(ERROR) << "GetParamForIpAdapterPattern failed, d_value is nullptr!";
316     return false;
317   }
318   auto q_reshape_BSND = q_trans_BNSD->input(kNumIndex1)->cast<CNodePtr>();
319   MS_CHECK_TRUE_RET(q_reshape_BSND != nullptr, false);
320   auto q_top_matmul = q_reshape_BSND->input(kNumIndex1)->cast<CNodePtr>();
321   MS_CHECK_TRUE_RET(q_top_matmul != nullptr, false);
322   auto q_matmul_input_2_shape = GetTensorShape(q_top_matmul, kNumIndex2);
323   auto k_reshape_BSND = k_trans_BNDS->input(kNumIndex1)->cast<CNodePtr>();
324   MS_CHECK_TRUE_RET(k_reshape_BSND != nullptr, false);
325   auto k_top_matmul = k_reshape_BSND->input(kNumIndex1)->cast<CNodePtr>();
326   MS_CHECK_TRUE_RET(k_top_matmul != nullptr, false);
327   auto k_matmul_input_2_shape = GetTensorShape(k_top_matmul, kNumIndex2);
328   MS_LOG(INFO) << "q_top_matmul name: " << q_top_matmul->fullname_with_scope()
329                << ", k_top_matmul name: " << k_top_matmul->fullname_with_scope()
330                << ", q matmul input2 shape: " << q_matmul_input_2_shape
331                << ", k matmul input2 shape: " << k_matmul_input_2_shape;
332   if (q_matmul_input_2_shape.size() != kNumShapeSize2 || k_matmul_input_2_shape.size() != kNumShapeSize2) {
333     MS_LOG(INFO) << "Matmul input 2 shape is not 2D, cannot fuse FA.";
334     return false;
335   }
336   auto num_h = q_matmul_input_2_shape[kNumIndex1];
337   *num_head = GetNumHeadForSD(q_reshape_BSND);
338   *d_value = num_h / *num_head;
339   return true;
340 }
341 }  // namespace
342 
343 std::string FlashAttentionFusion::soc_version_;
344 
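// Registers every supported flash-attention subgraph pattern, keyed by its pattern name.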
345 std::unordered_map<std::string, VectorRef> FlashAttentionFusion::DefinePatterns() const {
346   MS_LOG(INFO) << "start define flash attention fusion patterns.";
347   std::unordered_map<std::string, VectorRef> patterns;
348   patterns[kNameFlashAttentionPatternForMsSD21] = DefineFlashAttentionPatternForMsSD21();
349   patterns[kNameFlashAttentionPatternForMsSDXL] = DefineFlashAttentionPatternForMsSDXL();
350   patterns[kNameFlashAttentionPatternForVideoComposer] = DefineFlashAttentionPatternForVideoComposer();
351   patterns[kNameFlashAttentionPatternForSDBSH] = DefineFlashAttentionPatternForSDBSH();
352   patterns[kNameFlashAttentionPatternForSDWithoutCast] = DefineFlashAttentionPatternForSDWithoutCast();
353   patterns[kNameFlashAttentionPatternForSDPreMul] = DefineFlashAttentionPatternForSDPreMul();
354   patterns[kNameFlashAttentionPatternForPanGu] = DefineFlashAttentionPatternForPanGu();
355   patterns[kNameFlashAttentionPatternForLLAMAPatternV1] = DefineFlashAttentionPatternForLLAMAPatternV1();
356   patterns[kNameFlashAttentionPatternForLLAMAPatternV2] = DefineFlashAttentionPatternForLLAMAPatternV2();
357   patterns[kNameFlashAttentionPatternForBaiChuan] = DefineFlashAttentionPatternForBaiChuan();
358   patterns[kNameFlashAttentionPatternForMsSDPseShift] = DefineFlashAttentionPatternForMsSDPseShift();
359   patterns[kNameFlashAttentionPatternForSDEinsum] = DefineFlashAttentionPatternForSDEinsum();
360   return patterns;
361 }
362 
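// Creates a PadFusion node that zero-pads the last dimension of `node` by pad_size (CONSTANT mode, value 0).
// If the node name already carries the pad suffix, the existing node is returned unchanged.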
363 CNodePtr FlashAttentionFusion::CreatePadCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int32_t pad_size,
364                                               const std::string &node_name) const {
365   MS_LOG(INFO) << "add pad node for prompt flash attention.";
366   if (node->fullname_with_scope().find(kNamePadNodeSuffix) != std::string::npos) {
367     MS_LOG(WARNING) << "node name contains " << kNamePadNodeSuffix << ", pad node is not created";
368     return node->cast<CNodePtr>();
369   }
370   auto pad_prim = std::make_shared<ops::PadFusion>();
371   if (pad_prim == nullptr) {
372     MS_LOG(ERROR) << "new pad prim failed, prim is nullptr.";
373     return nullptr;
374   }
375 
376   pad_prim->AddAttr("padding_mode", api::MakeValue(PaddingMode::CONSTANT));
377   pad_prim->AddAttr("constant_value", api::MakeValue(0.0));
378   std::vector<std::vector<int32_t>> paddings = {{0, 0}, {0, 0}, {0, 0}, {0, pad_size}};
379 
380   auto pad_prim_c = pad_prim->GetPrim();
381   if (pad_prim_c == nullptr) {
382     MS_LOG(WARNING) << "pad_prim_c is nullptr.";
383     return nullptr;
384   }
385   AnfNodePtr paddings_node = BuildIntVec2DParameterNode(
386     func_graph, paddings, node->fullname_with_scope() + std::to_string(kNameIndex) + "_paddings");
387   if (paddings_node == nullptr) {
388     MS_LOG(WARNING) << "paddings_node is nullptr.";
389     return nullptr;
390   }
391   auto inputs = {node, paddings_node};
392   auto pad_cnode = func_graph->NewCNode(pad_prim_c, inputs);
393   if (pad_cnode == nullptr) {
394     MS_LOG(ERROR) << "new pad cnode failed, cnode is nullptr.";
395     return nullptr;
396   }
397   pad_cnode->set_fullname_with_scope(node->fullname_with_scope() + std::to_string(kNameIndex++) + kNamePadNodeSuffix);
398   if (node->abstract() != nullptr) {
399     pad_cnode->set_abstract(node->abstract()->Clone());
400   }
401   MS_LOG(INFO) << "create pad node end.";
402   return pad_cnode;
403 }
404 
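// Creates a Slice node that trims the last dimension of `node` back to slice_size
// (begin {0, 0, 0, 0}, size {-1, -1, -1, slice_size}), undoing the padding added for FA.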
405 CNodePtr FlashAttentionFusion::CreateSliceCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
406                                                 int32_t slice_size) const {
407   MS_LOG(INFO) << "add slice node for prompt flash attention.";
408   auto slice_prim = std::make_shared<ops::Slice>();
409   if (slice_prim == nullptr) {
410     MS_LOG(ERROR) << "new slice prim failed, prim is nullptr.";
411     return nullptr;
412   }
413 
414   std::vector<int32_t> begin = {0, 0, 0, 0};
415   std::vector<int32_t> size = {-1, -1, -1, slice_size};
416 
417   auto slice_prim_c = slice_prim->GetPrim();
418   if (slice_prim_c == nullptr) {
419     MS_LOG(ERROR) << "slice prim c is nullptr.";
420     return nullptr;
421   }
422 
423   AnfNodePtr begin_node = BuildIntVecParameterNode(func_graph, begin, node->fullname_with_scope() + "_begin");
424   if (begin_node == nullptr) {
425     MS_LOG(WARNING) << "BuildIntVecParameterNode failed.";
426     return nullptr;
427   }
428   AnfNodePtr size_node = BuildIntVecParameterNode(func_graph, size, node->fullname_with_scope() + "_size");
429   if (size_node == nullptr) {
430     MS_LOG(WARNING) << "BuildIntVecParameterNode failed.";
431     return nullptr;
432   }
433 
434   auto inputs = {node, begin_node, size_node};
435   auto slice_cnode = func_graph->NewCNode(slice_prim_c, inputs);
436   if (slice_cnode == nullptr) {
437     MS_LOG(WARNING) << "create slice_cnode failed.";
438     return nullptr;
439   }
440 
441   slice_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_fa_slice");
442   if (node->abstract() != nullptr) {
443     slice_cnode->set_abstract(node->abstract()->Clone());
444   }
445   MS_LOG(INFO) << "create slice node end.";
446   return slice_cnode;
447 }
448 
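// Matched subgraph: BatchMatMul(Reshape(Q), Transpose(K)) -> Mul(scale) -> Softmax -> BatchMatMul(V).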
449 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForMsSD21() const {
450   //  reshape Q
451   auto q_input = std::make_shared<Var>();
452   auto reshape_q_input_2 = std::make_shared<Var>();
453   MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
454   auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
455   MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
456   auto reshape_q = VectorRef({is_reshape_q, q_input, reshape_q_input_2});
457 
458   // transpose
459   auto k_input = std::make_shared<Var>();
460   auto is_transpose_param = std::make_shared<Var>();
461   MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
462   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
463   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
464   auto transpose = VectorRef({is_transpose, k_input, is_transpose_param});
465 
466   // matmul 1
467   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
468   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
469   auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
470   // q mul
471   auto is_mul_param = std::make_shared<Var>();
472   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
473   auto is_mul_qk = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
474   MS_CHECK_TRUE_RET(is_mul_qk != nullptr, {});
475   auto mul_qk = VectorRef({is_mul_qk, matmul_1, is_mul_param});
476   // softmax
477   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
478   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
479   auto softmax = VectorRef({is_softmax, mul_qk});
480 
481   // matmul 2
482   auto v = std::make_shared<Var>();  // input V
483   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
484   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
485   auto matmul_2 = VectorRef({is_matmul_2, softmax, v});
486   return matmul_2;
487 }
488 
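// Matched subgraph: MatMul(Reshape(Q), Transpose(Reshape(K))) -> Mul(scale) -> Add(pse shift) -> Softmax
// -> Cast -> MatMul(Reshape(V)) -> Reshape.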
489 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForMsSDPseShift() const {
490   // Q
491   auto q_trans_input = std::make_shared<Var>();
492   MS_CHECK_TRUE_RET(q_trans_input != nullptr, {});
493   auto reshape_q_input_2 = std::make_shared<Var>();
494   MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
495   auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
496   MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
497   auto reshape_q = VectorRef({is_reshape_q, q_trans_input, reshape_q_input_2});
498 
499   // K
500   auto k_trans_input = std::make_shared<Var>();
501   MS_CHECK_TRUE_RET(k_trans_input != nullptr, {});
502   auto reshape_k_input_2 = std::make_shared<Var>();
503   MS_CHECK_TRUE_RET(reshape_k_input_2 != nullptr, {});
504   auto is_reshape_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
505   MS_CHECK_TRUE_RET(is_reshape_k != nullptr, {});
506   auto reshape_k = VectorRef({is_reshape_k, k_trans_input, reshape_k_input_2});
507   MS_CHECK_TRUE_RET(reshape_k != nullptr, {});
508   auto is_transpose_param = std::make_shared<Var>();
509   MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
510   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
511   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
512   auto transpose = VectorRef({is_transpose, reshape_k, is_transpose_param});
513 
514   // matmul
515   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
516   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
517   auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
518 
519   // mul
520   auto is_mul_param = std::make_shared<Var>();
521   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
522   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
523   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
524   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
525 
526   // add
527   auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>);
528   MS_CHECK_TRUE_RET(is_add != nullptr, {});
529   auto add_input_2 = std::make_shared<Var>();
530   MS_CHECK_TRUE_RET(add_input_2 != nullptr, {});
531   auto add = VectorRef({is_add, mul, add_input_2});
532 
533   // softmax
534   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
535   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
536   auto softmax = VectorRef({is_softmax, add});
537 
538   // cast
539   auto is_cast_param = std::make_shared<Var>();
540   MS_CHECK_TRUE_RET(is_cast_param != nullptr, {});
541   auto is_cast = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
542   MS_CHECK_TRUE_RET(is_cast != nullptr, {});
543   auto cast = VectorRef({is_cast, softmax, is_cast_param});
544 
545   // V
546   auto v_trans_input = std::make_shared<Var>();
547   MS_CHECK_TRUE_RET(v_trans_input != nullptr, {});
548   auto reshape_v_input_2 = std::make_shared<Var>();
549   MS_CHECK_TRUE_RET(reshape_v_input_2 != nullptr, {});
550   auto is_reshape_v = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
551   MS_CHECK_TRUE_RET(is_reshape_v != nullptr, {});
552   auto reshape_v = VectorRef({is_reshape_v, v_trans_input, reshape_v_input_2});
553 
554   // matmul 2
555   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
556   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
557   auto matmul_2 = VectorRef({is_matmul_2, cast, reshape_v});
558 
559   // reshape
560   auto reshape_output_input_2 = std::make_shared<Var>();
561   MS_CHECK_TRUE_RET(reshape_output_input_2 != nullptr, {});
562   auto is_reshape_output = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
563   MS_CHECK_TRUE_RET(is_reshape_output != nullptr, {});
564   auto reshape_output = VectorRef({is_reshape_output, matmul_2, reshape_output_input_2});
565   return reshape_output;
566 }
567 
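// Matched subgraph: BatchMatMul(Q, K) -> Div(scale) -> Softmax -> BatchMatMul(V).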
568 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForMsSDXL() const {
569   // matmul 1
570   auto q = std::make_shared<Var>();  // input Q
571   MS_CHECK_TRUE_RET(q != nullptr, {});
572   auto k = std::make_shared<Var>();  // input K
573   MS_CHECK_TRUE_RET(k != nullptr, {});
574   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
575   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
576   auto matmul_1 = VectorRef({is_matmul_1, q, k});
577   // q div
578   auto is_div_q_param = std::make_shared<Var>();
579   MS_CHECK_TRUE_RET(is_div_q_param != nullptr, {});
580   auto is_div_q = std::make_shared<CondVar>(IsDivNode);
581   MS_CHECK_TRUE_RET(is_div_q != nullptr, {});
582   auto div_q = VectorRef({is_div_q, matmul_1, is_div_q_param});
583   // softmax
584   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
585   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
586   auto softmax = VectorRef({is_softmax, div_q});
587 
588   // matmul 2
589   auto v = std::make_shared<Var>();  // input V
590   MS_CHECK_TRUE_RET(v != nullptr, {});
591   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
592   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
593   auto matmul_2 = VectorRef({is_matmul_2, softmax, v});
594   return matmul_2;
595 }
596 
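// Matched subgraph: BatchMatMul(Reshape(Transpose(Q)), Transpose(Reshape(Transpose(K)))) -> Mul(scale)
// -> Cast -> Softmax -> Cast -> BatchMatMul(Reshape(Transpose(V))) -> Reshape.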
597 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForVideoComposer() const {
598   // q trans
599   auto input_q = std::make_shared<Var>();
600   MS_CHECK_TRUE_RET(input_q != nullptr, {});
601   auto input_q_perm = std::make_shared<Var>();
602   MS_CHECK_TRUE_RET(input_q_perm != nullptr, {});
603   auto is_q_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
604   MS_CHECK_TRUE_RET(is_q_transpose != nullptr, {});
605   auto q_transpose = VectorRef({is_q_transpose, input_q, input_q_perm});
606   // q reshape
607   auto reshape_q_input = std::make_shared<Var>();
608   MS_CHECK_TRUE_RET(reshape_q_input != nullptr, {});
609   auto is_q_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
610   MS_CHECK_TRUE_RET(is_q_reshape != nullptr, {});
611   auto reshape_q = VectorRef({is_q_reshape, q_transpose, reshape_q_input});
612   // k trans
613   auto input_k = std::make_shared<Var>();
614   MS_CHECK_TRUE_RET(input_k != nullptr, {});
615   auto input_k_perm = std::make_shared<Var>();
616   MS_CHECK_TRUE_RET(input_k_perm != nullptr, {});
617   auto is_k_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
618   MS_CHECK_TRUE_RET(is_k_transpose != nullptr, {});
619   auto k_transpose = VectorRef({is_k_transpose, input_k, input_k_perm});
620   // k reshape
621   auto reshape_k_input = std::make_shared<Var>();
622   MS_CHECK_TRUE_RET(reshape_k_input != nullptr, {});
623   auto is_k_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
624   MS_CHECK_TRUE_RET(is_k_reshape != nullptr, {});
625   auto reshape_k = VectorRef({is_k_reshape, k_transpose, reshape_k_input});
626   // k trans 2
627   auto input_k_trans_2_perm = std::make_shared<Var>();
628   MS_CHECK_TRUE_RET(input_k_trans_2_perm != nullptr, {});
629   auto is_k_transpese_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
630   MS_CHECK_TRUE_RET(is_k_transpese_2 != nullptr, {});
631   auto k_transpose_2 = VectorRef({is_k_transpese_2, reshape_k, input_k_trans_2_perm});
632 
633   // v trans
634   auto input_v = std::make_shared<Var>();
635   MS_CHECK_TRUE_RET(input_v != nullptr, {});
636   auto input_v_perm = std::make_shared<Var>();
637   MS_CHECK_TRUE_RET(input_v_perm != nullptr, {});
638   auto is_v_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
639   MS_CHECK_TRUE_RET(is_v_transpose != nullptr, {});
640   auto v_transpose = VectorRef({is_v_transpose, input_v, input_v_perm});
641   // v reshape
642   auto reshape_v_input = std::make_shared<Var>();
643   MS_CHECK_TRUE_RET(reshape_v_input != nullptr, {});
644   auto is_v_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
645   MS_CHECK_TRUE_RET(is_v_reshape != nullptr, {});
646   auto reshape_v = VectorRef({is_v_reshape, v_transpose, reshape_v_input});
647 
648   // matmul 1
649   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
650   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
651   auto matmul_1 = VectorRef({is_matmul_1, reshape_q, k_transpose_2});
652   // mul
653   auto is_mul_param = std::make_shared<Var>();
654   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
655   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
656   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
657   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
658 
659   // cast
660   auto is_cast_1_param = std::make_shared<Var>();
661   MS_CHECK_TRUE_RET(is_cast_1_param != nullptr, {});
662   auto is_cast_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
663   MS_CHECK_TRUE_RET(is_cast_1 != nullptr, {});
664   auto cast_1 = VectorRef({is_cast_1, mul, is_cast_1_param});
665 
666   // softmax
667   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
668   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
669   auto softmax = VectorRef({is_softmax, cast_1});
670   // cast
671   auto is_cast_param = std::make_shared<Var>();
672   MS_CHECK_TRUE_RET(is_cast_param != nullptr, {});
673   auto is_cast = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
674   MS_CHECK_TRUE_RET(is_cast != nullptr, {});
675   auto cast = VectorRef({is_cast, softmax, is_cast_param});
676   // matmul
677   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
678   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
679   auto matmul_2 = VectorRef({is_matmul_2, cast, reshape_v});
680 
681   // output reshape to four dims
682   auto reshape_o_2 = std::make_shared<Var>();
683   MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
684   auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
685   MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
686   auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
687   return reshape_o;
688 }
689 
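// Matched subgraph (BNSD layout): MatMul(Reshape(Q), Transpose(Reshape(K))) -> Mul(scale) -> Softmax
// -> Cast -> MatMul(Reshape(V)) -> Reshape.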
690 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDBNSD() const {
691   // Q reshape
692   auto reshape_q_input_1 = std::make_shared<Var>();  // input Q
693   MS_CHECK_TRUE_RET(reshape_q_input_1 != nullptr, {});
694   auto reshape_q_input_2 = std::make_shared<Var>();
695   MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
696   auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
697   MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
698   auto reshape_q = VectorRef({is_reshape_q, reshape_q_input_1, reshape_q_input_2});
699   // K reshape
700   auto reshape_k_input_1 = std::make_shared<Var>();  // input K
701   MS_CHECK_TRUE_RET(reshape_k_input_1 != nullptr, {});
702   auto reshape_k_input_2 = std::make_shared<Var>();
703   MS_CHECK_TRUE_RET(reshape_k_input_2 != nullptr, {});
704   auto is_reshape_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
705   MS_CHECK_TRUE_RET(is_reshape_k != nullptr, {});
706   auto reshape_k = VectorRef({is_reshape_k, reshape_k_input_1, reshape_k_input_2});
707   // transpose
708   auto is_transpose_param = std::make_shared<CondVar>(IsParamNode);
709   MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
710   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
711   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
712   auto transpose = VectorRef({is_transpose, reshape_k, is_transpose_param});
713   // matmul 1
714   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
715   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
716   auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
717   // mul
718   auto is_mul_param = std::make_shared<CondVar>(IsParamNode);
719   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
720   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
721   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
722   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
723   // softmax
724   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
725   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
726   auto softmax = VectorRef({is_softmax, mul});
727   // cast
728   auto is_cast_param = std::make_shared<CondVar>(IsParamNode);
729   MS_CHECK_TRUE_RET(is_cast_param != nullptr, {});
730   auto is_cast = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
731   MS_CHECK_TRUE_RET(is_cast != nullptr, {});
732   auto cast = VectorRef({is_cast, softmax, is_cast_param});
733   // V reshape
734   auto reshape_v_input_1 = std::make_shared<Var>();  // input V
735   MS_CHECK_TRUE_RET(reshape_v_input_1 != nullptr, {});
736   auto reshape_v_input_2 = std::make_shared<Var>();
737   MS_CHECK_TRUE_RET(reshape_v_input_2 != nullptr, {});
738   auto is_reshape_v = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
739   MS_CHECK_TRUE_RET(is_reshape_v != nullptr, {});
740   auto reshape_v = VectorRef({is_reshape_v, reshape_v_input_1, reshape_v_input_2});
741   // matmul
742   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
743   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
744   auto matmul_2 = VectorRef({is_matmul_2, cast, reshape_v});
745   // output reshape to four dims
746   auto reshape_o_2 = std::make_shared<Var>();
747   MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
748   auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
749   MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
750   auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
751   return reshape_o;
752 }
753 
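// Matched subgraph (BSH layout): Q/K/V each pass through Reshape -> Transpose -> Reshape, then
// MatMul -> Mul(scale) -> Softmax -> Cast -> MatMul, with Reshape -> Transpose -> Reshape on the output.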
754 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDBSH() const {
755   // Q: three dim input reshape to four dims
756   auto input_q_reshape_param_1 = std::make_shared<Var>();
757   MS_CHECK_TRUE_RET(input_q_reshape_param_1 != nullptr, {});
758   auto input_q_reshape_param_2 = std::make_shared<Var>();
759   MS_CHECK_TRUE_RET(input_q_reshape_param_2 != nullptr, {});
760   auto is_input_q_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
761   MS_CHECK_TRUE_RET(is_input_q_reshape != nullptr, {});
762   auto input_q_reshape = VectorRef({is_input_q_reshape, input_q_reshape_param_1, input_q_reshape_param_2});
763   //  transpose
764   auto is_input_q_transpose_param = std::make_shared<Var>();
765   MS_CHECK_TRUE_RET(is_input_q_transpose_param != nullptr, {});
766   auto is_input_q_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
767   MS_CHECK_TRUE_RET(is_input_q_transpose != nullptr, {});
768   auto input_q_transpose = VectorRef({is_input_q_transpose, input_q_reshape, is_input_q_transpose_param});
769   // Q reshape
770   auto reshape_q_input_2 = std::make_shared<Var>();
771   MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
772   auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
773   MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
774   auto reshape_q = VectorRef({is_reshape_q, input_q_transpose, reshape_q_input_2});
775 
776   // K: three dim input reshape to four dims
777   auto input_k_reshape_param_1 = std::make_shared<Var>();
778   MS_CHECK_TRUE_RET(input_k_reshape_param_1 != nullptr, {});
779   auto input_k_reshape_param_2 = std::make_shared<Var>();
780   MS_CHECK_TRUE_RET(input_k_reshape_param_2 != nullptr, {});
781   auto is_input_k_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
782   MS_CHECK_TRUE_RET(is_input_k_reshape != nullptr, {});
783   auto input_k_reshape = VectorRef({is_input_k_reshape, input_k_reshape_param_1, input_k_reshape_param_2});
784   //  transpose
785   auto is_input_k_transpose_param = std::make_shared<Var>();
786   MS_CHECK_TRUE_RET(is_input_k_transpose_param != nullptr, {});
787   auto is_input_k_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
788   MS_CHECK_TRUE_RET(is_input_k_transpose != nullptr, {});
789   auto input_k_transpose = VectorRef({is_input_k_transpose, input_k_reshape, is_input_k_transpose_param});
790   // K reshape
791   auto reshape_k_input_2 = std::make_shared<Var>();
792   MS_CHECK_TRUE_RET(reshape_k_input_2 != nullptr, {});
793   auto is_reshape_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
794   MS_CHECK_TRUE_RET(is_reshape_k != nullptr, {});
795   auto reshape_k = VectorRef({is_reshape_k, input_k_transpose, reshape_k_input_2});
796   // transpose
797   auto is_transpose_param = std::make_shared<CondVar>(IsParamNode);
798   MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
799   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
800   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
801   auto transpose = VectorRef({is_transpose, reshape_k, is_transpose_param});
802   // matmul 1
803   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
804   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
805   auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
806   // mul
807   auto is_mul_param = std::make_shared<CondVar>(IsParamNode);
808   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
809   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
810   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
811   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
812   // softmax
813   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
814   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
815   auto softmax = VectorRef({is_softmax, mul});
816   // cast
817   auto is_cast_param = std::make_shared<CondVar>(IsParamNode);
818   MS_CHECK_TRUE_RET(is_cast_param != nullptr, {});
819   auto is_cast = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
820   MS_CHECK_TRUE_RET(is_cast != nullptr, {});
821   auto cast = VectorRef({is_cast, softmax, is_cast_param});
822 
823   // V: three dim input reshape to four dims
824   auto input_v_reshape_param_1 = std::make_shared<Var>();
825   MS_CHECK_TRUE_RET(input_v_reshape_param_1 != nullptr, {});
826   auto input_v_reshape_param_2 = std::make_shared<Var>();
827   MS_CHECK_TRUE_RET(input_v_reshape_param_2 != nullptr, {});
828   auto is_input_v_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
829   MS_CHECK_TRUE_RET(is_input_v_reshape != nullptr, {});
830   auto input_v_reshape = VectorRef({is_input_v_reshape, input_v_reshape_param_1, input_v_reshape_param_2});
831   //  transpose
832   auto is_input_v_transpose_param = std::make_shared<CondVar>(IsParamNode);
833   MS_CHECK_TRUE_RET(is_input_v_transpose_param != nullptr, {});
834   auto is_input_v_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
835   MS_CHECK_TRUE_RET(is_input_v_transpose != nullptr, {});
836   auto input_v_transpose = VectorRef({is_input_v_transpose, input_v_reshape, is_input_v_transpose_param});
837   // V reshape
838   auto reshape_v_input_2 = std::make_shared<Var>();
839   MS_CHECK_TRUE_RET(reshape_v_input_2 != nullptr, {});
840   auto is_reshape_v = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
841   MS_CHECK_TRUE_RET(is_reshape_v != nullptr, {});
842   auto reshape_v = VectorRef({is_reshape_v, input_v_transpose, reshape_v_input_2});
843   // matmul
844   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
845   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
846   auto matmul_2 = VectorRef({is_matmul_2, cast, reshape_v});
847   // output reshape to four dims
848   auto reshape_o_2 = std::make_shared<Var>();
849   MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
850   auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
851   MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
852   auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
853   // output transpose
854   auto is_transpose_o_param = std::make_shared<CondVar>(IsParamNode);
855   MS_CHECK_TRUE_RET(is_transpose_o_param != nullptr, {});
856   auto is_transpose_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
857   MS_CHECK_TRUE_RET(is_transpose_o != nullptr, {});
858   auto transpose_o = VectorRef({is_transpose_o, reshape_o, is_transpose_o_param});
859   // output reshape to three dims
860   auto reshape_o2_2 = std::make_shared<Var>();
861   MS_CHECK_TRUE_RET(reshape_o2_2 != nullptr, {});
862   auto is_reshape_o2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
863   MS_CHECK_TRUE_RET(is_reshape_o2 != nullptr, {});
864   auto reshape_o2 = VectorRef({is_reshape_o2, transpose_o, reshape_o2_2});
865   return reshape_o2;
866 }
867 
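// Matched subgraph with the scale applied before the matmul: MatMul(Mul(Q), Mul(K)) -> Softmax
// -> MatMul(V) -> Transpose -> Reshape.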
868 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDPreMul() const {
869   // Q
870   auto q_input = std::make_shared<Var>();  // input Q
871   MS_CHECK_TRUE_RET(q_input != nullptr, {});
872   // mul
873   auto mul_q_val = std::make_shared<Var>();
874   MS_CHECK_TRUE_RET(mul_q_val != nullptr, {});
875   auto is_mul_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
876   MS_CHECK_TRUE_RET(is_mul_q != nullptr, {});
877   auto mul_q = VectorRef({is_mul_q, q_input, mul_q_val});
878 
879   // K
880   auto k_input = std::make_shared<Var>();  // input K
881   MS_CHECK_TRUE_RET(k_input != nullptr, {});
882   // mul
883   auto mul_k_val = std::make_shared<Var>();
884   MS_CHECK_TRUE_RET(mul_k_val != nullptr, {});
885   auto is_mul_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
886   MS_CHECK_TRUE_RET(is_mul_k != nullptr, {});
887   auto mul_k = VectorRef({is_mul_k, k_input, mul_k_val});
888 
889   // matmul 1
890   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
891   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
892   auto matmul_1 = VectorRef({is_matmul_1, mul_q, mul_k});
893 
894   // softmax
895   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
896   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
897   auto softmax = VectorRef({is_softmax, matmul_1});
898 
899   // matmul
900   auto mul_v = std::make_shared<Var>();
901   MS_CHECK_TRUE_RET(mul_v != nullptr, {});
902   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
903   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
904   auto matmul_2 = VectorRef({is_matmul_2, softmax, mul_v});
905 
906   auto output_trans_input_2 = std::make_shared<Var>();
907   MS_CHECK_TRUE_RET(output_trans_input_2 != nullptr, {});
908   auto is_trans_output = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
909   MS_CHECK_TRUE_RET(is_trans_output != nullptr, {});
910   auto transpose_output = VectorRef({is_trans_output, matmul_2, output_trans_input_2});
911   MS_CHECK_TRUE_RET(transpose_output != nullptr, {});
912 
913   auto output_reshape_input_2 = std::make_shared<Var>();
914   MS_CHECK_TRUE_RET(output_reshape_input_2 != nullptr, {});
915   auto is_reshape_output = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
916   MS_CHECK_TRUE_RET(is_reshape_output != nullptr, {});
917   auto output_reshape = VectorRef({is_reshape_output, transpose_output, output_reshape_input_2});
918   MS_CHECK_TRUE_RET(output_reshape != nullptr, {});
919   return output_reshape;
920 }
921 
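// Same structure as the SD BSH pattern above, but without the Cast between Softmax and the second MatMul.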
922 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDWithoutCast() const {
923   // Q: three dim input reshape to four dims
924   auto input_q_reshape_param_1 = std::make_shared<Var>();
925   MS_CHECK_TRUE_RET(input_q_reshape_param_1 != nullptr, {});
926   auto input_q_reshape_param_2 = std::make_shared<Var>();
927   MS_CHECK_TRUE_RET(input_q_reshape_param_2 != nullptr, {});
928   auto is_input_q_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
929   MS_CHECK_TRUE_RET(is_input_q_reshape != nullptr, {});
930   auto input_q_reshape = VectorRef({is_input_q_reshape, input_q_reshape_param_1, input_q_reshape_param_2});
931   //  transpose
932   auto is_input_q_transpose_param = std::make_shared<Var>();
933   MS_CHECK_TRUE_RET(is_input_q_transpose_param != nullptr, {});
934   auto is_input_q_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
935   MS_CHECK_TRUE_RET(is_input_q_transpose != nullptr, {});
936   auto input_q_transpose = VectorRef({is_input_q_transpose, input_q_reshape, is_input_q_transpose_param});
937   // Q reshape
938   auto reshape_q_input_2 = std::make_shared<Var>();
939   MS_CHECK_TRUE_RET(reshape_q_input_2 != nullptr, {});
940   auto is_reshape_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
941   MS_CHECK_TRUE_RET(is_reshape_q != nullptr, {});
942   auto reshape_q = VectorRef({is_reshape_q, input_q_transpose, reshape_q_input_2});
943 
944   // K: three dim input reshape to four dims
945   auto input_k_reshape_param_1 = std::make_shared<Var>();
946   MS_CHECK_TRUE_RET(input_k_reshape_param_1 != nullptr, {});
947   auto input_k_reshape_param_2 = std::make_shared<Var>();
948   MS_CHECK_TRUE_RET(input_k_reshape_param_2 != nullptr, {});
949   auto is_input_k_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
950   MS_CHECK_TRUE_RET(is_input_k_reshape != nullptr, {});
951   auto input_k_reshape = VectorRef({is_input_k_reshape, input_k_reshape_param_1, input_k_reshape_param_2});
952   //  transpose
953   auto is_input_k_transpose_param = std::make_shared<Var>();
954   MS_CHECK_TRUE_RET(is_input_k_transpose_param != nullptr, {});
955   auto is_input_k_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
956   MS_CHECK_TRUE_RET(is_input_k_transpose != nullptr, {});
957   auto input_k_transpose = VectorRef({is_input_k_transpose, input_k_reshape, is_input_k_transpose_param});
958   // K reshape
959   auto reshape_k_input_2 = std::make_shared<Var>();
960   MS_CHECK_TRUE_RET(reshape_k_input_2 != nullptr, {});
961   auto is_reshape_k = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
962   MS_CHECK_TRUE_RET(is_reshape_k != nullptr, {});
963   auto reshape_k = VectorRef({is_reshape_k, input_k_transpose, reshape_k_input_2});
964   // transpose
965   auto is_transpose_param = std::make_shared<CondVar>(IsParamNode);
966   MS_CHECK_TRUE_RET(is_transpose_param != nullptr, {});
967   auto is_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
968   MS_CHECK_TRUE_RET(is_transpose != nullptr, {});
969   auto transpose = VectorRef({is_transpose, reshape_k, is_transpose_param});
970   // matmul 1
971   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
972   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
973   auto matmul_1 = VectorRef({is_matmul_1, reshape_q, transpose});
974   // mul
975   auto is_mul_param = std::make_shared<CondVar>(IsParamNode);
976   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
977   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
978   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
979   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
980   // softmax
981   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
982   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
983   auto softmax = VectorRef({is_softmax, mul});
984 
985   // V: three dim input reshape to four dims
986   auto input_v_reshape_param_1 = std::make_shared<Var>();
987   MS_CHECK_TRUE_RET(input_v_reshape_param_1 != nullptr, {});
988   auto input_v_reshape_param_2 = std::make_shared<Var>();
989   MS_CHECK_TRUE_RET(input_v_reshape_param_2 != nullptr, {});
990   auto is_input_v_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
991   MS_CHECK_TRUE_RET(is_input_v_reshape != nullptr, {});
992   auto input_v_reshape = VectorRef({is_input_v_reshape, input_v_reshape_param_1, input_v_reshape_param_2});
993   //  transpose
994   auto is_input_v_transpose_param = std::make_shared<CondVar>(IsParamNode);
995   MS_CHECK_TRUE_RET(is_input_v_transpose_param != nullptr, {});
996   auto is_input_v_transpose = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
997   MS_CHECK_TRUE_RET(is_input_v_transpose != nullptr, {});
998   auto input_v_transpose = VectorRef({is_input_v_transpose, input_v_reshape, is_input_v_transpose_param});
999   // V reshape
1000   auto reshape_v_input_2 = std::make_shared<Var>();
1001   MS_CHECK_TRUE_RET(reshape_v_input_2 != nullptr, {});
1002   auto is_reshape_v = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1003   MS_CHECK_TRUE_RET(is_reshape_v != nullptr, {});
1004   auto reshape_v = VectorRef({is_reshape_v, input_v_transpose, reshape_v_input_2});
1005   // matmul
1006   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
1007   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1008   auto matmul_2 = VectorRef({is_matmul_2, softmax, reshape_v});
1009   // output reshape to four dims
1010   auto reshape_o_2 = std::make_shared<Var>();
1011   MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
1012   auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1013   MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
1014   auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
1015   // output transpose
1016   auto is_transpose_o_param = std::make_shared<CondVar>(IsParamNode);
1017   MS_CHECK_TRUE_RET(is_transpose_o_param != nullptr, {});
1018   auto is_transpose_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
1019   MS_CHECK_TRUE_RET(is_transpose_o != nullptr, {});
1020   auto transpose_o = VectorRef({is_transpose_o, reshape_o, is_transpose_o_param});
1021   // output reshape to three dims
1022   auto reshape_o2_2 = std::make_shared<Var>();
1023   MS_CHECK_TRUE_RET(reshape_o2_2 != nullptr, {});
1024   auto is_reshape_o2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1025   MS_CHECK_TRUE_RET(is_reshape_o2 != nullptr, {});
1026   auto reshape_o2 = VectorRef({is_reshape_o2, transpose_o, reshape_o2_2});
1027   return reshape_o2;
1028 }
1029 
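// Matched subgraph: BatchMatMul(RealDiv(Q, scale), K) -> Cast -> Add(Mul(attention mask)) -> Softmax
// -> Cast -> BatchMatMul(V).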
1030 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForPanGu() const {
1031   // q div
1032   auto q = std::make_shared<Var>();  // input Q
1033   MS_CHECK_TRUE_RET(q != nullptr, {});
1034   auto is_div_q_param = std::make_shared<Var>();
1035   MS_CHECK_TRUE_RET(is_div_q_param != nullptr, {});
1036   auto is_div_q = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimRealDiv>);
1037   MS_CHECK_TRUE_RET(is_div_q != nullptr, {});
1038   auto div_q = VectorRef({is_div_q, q, is_div_q_param});
1039   // matmul 1
1040   auto k = std::make_shared<Var>();  // input K
1041   MS_CHECK_TRUE_RET(k != nullptr, {});
1042   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1043   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1044   auto matmul_1 = VectorRef({is_matmul_1, div_q, k});
1045   // cast 1
1046   auto is_cast_1_param = std::make_shared<Var>();
1047   MS_CHECK_TRUE_RET(is_cast_1_param != nullptr, {});
1048   auto is_cast_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
1049   MS_CHECK_TRUE_RET(is_cast_1 != nullptr, {});
1050   auto cast_1 = VectorRef({is_cast_1, matmul_1, is_cast_1_param});
1051   // ===== attention mask =====
1052   // sub
1053   auto atten_mask = std::make_shared<Var>();
1054   MS_CHECK_TRUE_RET(atten_mask != nullptr, {});
1055   // mul
1056   auto is_mask_mul_param = std::make_shared<Var>();
1057   MS_CHECK_TRUE_RET(is_mask_mul_param != nullptr, {});
1058   auto is_mask_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1059   MS_CHECK_TRUE_RET(is_mask_mul != nullptr, {});
1060   auto mask_mul = VectorRef({is_mask_mul, atten_mask, is_mask_mul_param});
1061   // ===== end attention mask =====
1062   // add
1063   auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1064   MS_CHECK_TRUE_RET(is_add != nullptr, {});
1065   auto add = VectorRef({is_add, mask_mul, cast_1});
1066   // softmax
1067   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1068   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1069   auto softmax = VectorRef({is_softmax, add});
1070   // cast 2
1071   auto is_cast_2_param = std::make_shared<Var>();
1072   MS_CHECK_TRUE_RET(is_cast_2_param != nullptr, {});
1073   auto is_cast_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
1074   MS_CHECK_TRUE_RET(is_cast_2 != nullptr, {});
1075   auto cast_2 = VectorRef({is_cast_2, softmax, is_cast_2_param});
1076 
1077   // matmul 2
1078   auto v = std::make_shared<Var>();  // input V
1079   MS_CHECK_TRUE_RET(v != nullptr, {});
1080   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1081   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1082   auto matmul_2 = VectorRef({is_matmul_2, cast_2, v});
1083   return matmul_2;
1084 }
1085 
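// Pattern sketch for LLAMA V1: BatchMatMul(Q, K) -> Mul(scale) -> Add(Mul(atten_mask, c), .) -> Cast ->
// Softmax -> Cast -> BatchMatMul(., V).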
1086 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForLLAMAPatternV1() const {
1087   // matmul 1
1088   auto matmul_1_q_input = std::make_shared<Var>();  // input Q
1089   MS_CHECK_TRUE_RET(matmul_1_q_input != nullptr, {});
1090   auto matmul_1_k_input = std::make_shared<Var>();  // input K
1091   MS_CHECK_TRUE_RET(matmul_1_k_input != nullptr, {});
1092   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1093   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1094   auto matmul_1 = VectorRef({is_matmul_1, matmul_1_q_input, matmul_1_k_input});
1095   // mul
1096   auto is_mul_param = std::make_shared<Var>();
1097   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
1098   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1099   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
1100   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
1101   // ===== attention mask =====
1102   // sub
1103   auto sub_mask_input_1 = std::make_shared<Var>();  // input attention mask
1104   MS_CHECK_TRUE_RET(sub_mask_input_1 != nullptr, {});
1105   // mul
1106   auto is_mask_mul_param = std::make_shared<Var>();
1107   MS_CHECK_TRUE_RET(is_mask_mul_param != nullptr, {});
1108   auto is_mask_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1109   MS_CHECK_TRUE_RET(is_mask_mul != nullptr, {});
1110   auto mask_mul = VectorRef({is_mask_mul, sub_mask_input_1, is_mask_mul_param});
1111   // ===== end attention mask =====
1112   // add
1113   auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1114   MS_CHECK_TRUE_RET(is_add != nullptr, {});
1115   auto add = VectorRef({is_add, mask_mul, mul});
1116   // cast 1
1117   auto is_cast_1_param = std::make_shared<Var>();
1118   MS_CHECK_TRUE_RET(is_cast_1_param != nullptr, {});
1119   auto is_cast_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
1120   MS_CHECK_TRUE_RET(is_cast_1 != nullptr, {});
1121   auto cast_1 = VectorRef({is_cast_1, add, is_cast_1_param});
1122   // softmax
1123   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1124   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1125   auto softmax = VectorRef({is_softmax, cast_1});
1126   // cast 2
1127   auto is_cast_2_param = std::make_shared<Var>();
1128   MS_CHECK_TRUE_RET(is_cast_2_param != nullptr, {});
1129   auto is_cast_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimCast>);
1130   MS_CHECK_TRUE_RET(is_cast_2 != nullptr, {});
1131   auto cast_2 = VectorRef({is_cast_2, softmax, is_cast_2_param});
1132   // matmul
1133   auto v_input = std::make_shared<Var>();  // input V
1134   MS_CHECK_TRUE_RET(v_input != nullptr, {});
1135   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1136   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1137   auto matmul_2 = VectorRef({is_matmul_2, cast_2, v_input});
1138   return matmul_2;
1139 }
1140 
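// LLAMA V2 is the same chain as V1 above, just without the two Cast nodes around Softmax.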
1141 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForLLAMAPatternV2() const {
1142   // matmul 1
1143   auto matmul_1_q_input = std::make_shared<Var>();  // input Q
1144   MS_CHECK_TRUE_RET(matmul_1_q_input != nullptr, {});
1145   auto matmul_1_k_input = std::make_shared<Var>();  // input K
1146   MS_CHECK_TRUE_RET(matmul_1_k_input != nullptr, {});
1147   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1148   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1149   auto matmul_1 = VectorRef({is_matmul_1, matmul_1_q_input, matmul_1_k_input});
1150   // mul
1151   auto is_mul_param = std::make_shared<Var>();
1152   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
1153   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1154   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
1155   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
1156   // ===== attention mask =====
1157   // sub
1158   auto sub_mask_input_1 = std::make_shared<Var>();  // input attention mask
1159   MS_CHECK_TRUE_RET(sub_mask_input_1 != nullptr, {});
1160   // mul
1161   auto is_mask_mul_param = std::make_shared<Var>();
1162   MS_CHECK_TRUE_RET(is_mask_mul_param != nullptr, {});
1163   auto is_mask_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1164   MS_CHECK_TRUE_RET(is_mask_mul != nullptr, {});
1165   auto mask_mul = VectorRef({is_mask_mul, sub_mask_input_1, is_mask_mul_param});
1166   // ===== end attention mask =====
1167   // add
1168   auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1169   MS_CHECK_TRUE_RET(is_add != nullptr, {});
1170   auto add = VectorRef({is_add, mask_mul, mul});
1171   // softmax
1172   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1173   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1174   auto softmax = VectorRef({is_softmax, add});
1175   // matmul
1176   auto v_input = std::make_shared<Var>();  // input V
1177   MS_CHECK_TRUE_RET(v_input != nullptr, {});
1178   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1179   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1180   auto matmul_2 = VectorRef({is_matmul_2, softmax, v_input});
1181   return matmul_2;
1182 }
1183 
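// Pattern sketch for BaiChuan: BatchMatMul(Q, K) -> Mul(scale) -> Add(bias) -> Add(Mul-based mask) ->
// Softmax -> BatchMatMul(., V); the bias operand is only constrained to be a Var here.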
1184 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForBaiChuan() const {
1185   // matmul 1
1186   auto matmul_1_q_input = std::make_shared<Var>();  // input Q
1187   MS_CHECK_TRUE_RET(matmul_1_q_input != nullptr, {});
1188   auto matmul_1_k_input = std::make_shared<Var>();  // input K
1189   MS_CHECK_TRUE_RET(matmul_1_k_input != nullptr, {});
1190   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1191   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1192   auto matmul_1 = VectorRef({is_matmul_1, matmul_1_q_input, matmul_1_k_input});
1193   // mul
1194   auto is_mul_param = std::make_shared<Var>();
1195   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
1196   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1197   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
1198   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
1199   // ===== attention mask =====
1200   // mul
1201   auto is_mask_mul_param1 = std::make_shared<Var>();
1202   MS_CHECK_TRUE_RET(is_mask_mul_param1 != nullptr, {});
1203   auto is_mask_mul_param2 = std::make_shared<Var>();
1204   MS_CHECK_TRUE_RET(is_mask_mul_param2 != nullptr, {});
1205   auto is_mask_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMul>);
1206   MS_CHECK_TRUE_RET(is_mask_mul != nullptr, {});
1207   auto mask_mul = VectorRef({is_mask_mul, is_mask_mul_param1, is_mask_mul_param2});
1208   // ===== end attention mask =====
1209   // add
1210   auto is_add_param = std::make_shared<Var>();
1211   MS_CHECK_TRUE_RET(is_add_param != nullptr, {});
1212   auto is_add = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1213   MS_CHECK_TRUE_RET(is_add != nullptr, {});
1214   auto add = VectorRef({is_add, mul, is_add_param});
1215   // add for mask
1216   auto is_add_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAdd>);
1217   MS_CHECK_TRUE_RET(is_add_2 != nullptr, {});
1218   auto add_2 = VectorRef({is_add_2, mask_mul, add});
1219   // softmax
1220   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1221   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1222   auto softmax = VectorRef({is_softmax, add_2});
1223   // matmul
1224   auto v_input = std::make_shared<Var>();  // input V
1225   MS_CHECK_TRUE_RET(v_input != nullptr, {});
1226   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimBatchMatMul>);
1227   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1228   auto matmul_2 = VectorRef({is_matmul_2, softmax, v_input});
1229   return matmul_2;
1230 }
1231 
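// Pattern sketch for the SD Einsum variant: MatMulFusion(Reshape(Q), Reshape(K)) -> Mul(scale) -> Softmax ->
// MatMulFusion(., Reshape(V)) -> Reshape -> Transpose -> Reshape; the Einsum nodes were already converted to
// MatMulFusion by onnx_einsum_adjust.cc.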
1232 const VectorRef FlashAttentionFusion::DefineFlashAttentionPatternForSDEinsum() const {
1233   // Q reshape
1234   auto input_q_reshape_param_1 = std::make_shared<Var>();
1235   MS_CHECK_TRUE_RET(input_q_reshape_param_1 != nullptr, {});
1236   auto input_q_reshape_param_2 = std::make_shared<Var>();
1237   MS_CHECK_TRUE_RET(input_q_reshape_param_2 != nullptr, {});
1238   auto is_input_q_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1239   MS_CHECK_TRUE_RET(is_input_q_reshape != nullptr, {});
1240   auto q_reshape = VectorRef({is_input_q_reshape, input_q_reshape_param_1, input_q_reshape_param_2});
1241 
1242   // K reshape
1243   auto input_k_reshape_param_1 = std::make_shared<Var>();
1244   MS_CHECK_TRUE_RET(input_k_reshape_param_1 != nullptr, {});
1245   auto input_k_reshape_param_2 = std::make_shared<Var>();
1246   MS_CHECK_TRUE_RET(input_k_reshape_param_2 != nullptr, {});
1247   auto is_input_k_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1248   MS_CHECK_TRUE_RET(is_input_k_reshape != nullptr, {});
1249   auto k_reshape = VectorRef({is_input_k_reshape, input_k_reshape_param_1, input_k_reshape_param_2});
1250 
1251   // matmul 1, einsum is replaced in onnx_einsum_adjust.cc
1252   auto is_matmul_1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
1253   MS_CHECK_TRUE_RET(is_matmul_1 != nullptr, {});
1254   auto matmul_1 = VectorRef({is_matmul_1, q_reshape, k_reshape});
1255   // mul
1256   auto is_mul_param = std::make_shared<CondVar>(IsParamNode);
1257   MS_CHECK_TRUE_RET(is_mul_param != nullptr, {});
1258   auto is_mul = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>);
1259   MS_CHECK_TRUE_RET(is_mul != nullptr, {});
1260   auto mul = VectorRef({is_mul, matmul_1, is_mul_param});
1261   // softmax
1262   auto is_softmax = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSoftmax>);
1263   MS_CHECK_TRUE_RET(is_softmax != nullptr, {});
1264   auto softmax = VectorRef({is_softmax, mul});
1265 
1266   // V reshape
1267   auto input_v_reshape_param_1 = std::make_shared<Var>();
1268   MS_CHECK_TRUE_RET(input_v_reshape_param_1 != nullptr, {});
1269   auto input_v_reshape_param_2 = std::make_shared<Var>();
1270   MS_CHECK_TRUE_RET(input_v_reshape_param_2 != nullptr, {});
1271   auto is_input_v_reshape = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1272   MS_CHECK_TRUE_RET(is_input_v_reshape != nullptr, {});
1273   auto v_reshape = VectorRef({is_input_v_reshape, input_v_reshape_param_1, input_v_reshape_param_2});
1274 
1275   // matmul 2
1276   auto is_matmul_2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMatMulFusion>);
1277   MS_CHECK_TRUE_RET(is_matmul_2 != nullptr, {});
1278   auto matmul_2 = VectorRef({is_matmul_2, softmax, v_reshape});
1279   // output reshape to four dims
1280   auto reshape_o_2 = std::make_shared<Var>();
1281   MS_CHECK_TRUE_RET(reshape_o_2 != nullptr, {});
1282   auto is_reshape_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1283   MS_CHECK_TRUE_RET(is_reshape_o != nullptr, {});
1284   auto reshape_o = VectorRef({is_reshape_o, matmul_2, reshape_o_2});
1285   // output transpose
1286   auto is_transpose_o_param = std::make_shared<CondVar>(IsParamNode);
1287   MS_CHECK_TRUE_RET(is_transpose_o_param != nullptr, {});
1288   auto is_transpose_o = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimTranspose>);
1289   MS_CHECK_TRUE_RET(is_transpose_o != nullptr, {});
1290   auto transpose_o = VectorRef({is_transpose_o, reshape_o, is_transpose_o_param});
1291   // output reshape to three dims
1292   auto reshape_o2_2 = std::make_shared<Var>();
1293   MS_CHECK_TRUE_RET(reshape_o2_2 != nullptr, {});
1294   auto is_reshape_o2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimReshape>);
1295   MS_CHECK_TRUE_RET(is_reshape_o2 != nullptr, {});
1296   auto reshape_o2 = VectorRef({is_reshape_o2, transpose_o, reshape_o2_2});
1297   return reshape_o2;
1298 }
1299 
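// Builds a PromptFlashAttention cnode in BNSD layout: num_heads, next_tokens, scale_value,
// num_key_value_heads, inner_precise and sparse_mode are attached as attrs, and the inputs are
// {q, k, v} plus atten_mask when one is available.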
1300 CNodePtr FlashAttentionFusion::CreatePromptFlashAttentionCnodeForBNSD(
1301   const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
1302   const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token, float scale_value,
1303   const std::shared_ptr<FlashAttentionParm> &fa_parm, int64_t num_key_value_heads) const {
1304   MS_LOG(INFO) << "CreatePromptFlashAttentionCnodeForBNSD";
1305   MS_LOG(INFO) << "num heads: " << num_heads << ", input layout: BNSD, next tokens: " << next_token
1306                << ", scale value: " << scale_value << ", num_key_value_heads: " << num_key_value_heads;
1307   MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << ", k name: " << k->fullname_with_scope()
1308                << ", v name: " << v->fullname_with_scope();
1309   if (num_heads < 0 || scale_value < 0 || next_token < 0 || num_key_value_heads < 0) {
1310     MS_LOG(WARNING) << "flash attention param is invalid";
1311     return nullptr;
1312   }
1313   if (fa_parm == nullptr) {
1314     MS_LOG(WARNING) << "FA parameter is null, please check.";
1315     return nullptr;
1316   }
1317   // create op
1318   auto prompt_flash_attention_prim = std::make_shared<ops::PromptFlashAttention>();
1319   if (prompt_flash_attention_prim == nullptr) {
1320     MS_LOG(ERROR) << "new prompt flash attention prim failed.";
1321     return nullptr;
1322   }
1323   // add attr
1324   prompt_flash_attention_prim->AddAttr("num_heads", api::MakeValue(num_heads));
1325   prompt_flash_attention_prim->AddAttr("input_layout", api::MakeValue("BNSD"));
1326   prompt_flash_attention_prim->AddAttr("next_tokens", api::MakeValue(next_token));
1327   prompt_flash_attention_prim->AddAttr("scale_value", api::MakeValue(scale_value));
1328   prompt_flash_attention_prim->AddAttr("num_key_value_heads", api::MakeValue(num_key_value_heads));
1329   prompt_flash_attention_prim->AddAttr("inner_precise", api::MakeValue(fa_parm->inner_precise));
1330   prompt_flash_attention_prim->AddAttr("sparse_mode", api::MakeValue(fa_parm->sparse_mode));
1331 
1332   auto fa_prim_c = prompt_flash_attention_prim->GetPrim();
1333   if (fa_prim_c == nullptr) {
1334     MS_LOG(ERROR) << "fa_prim_c is nullptr.";
1335     return nullptr;
1336   }
1337   CNodePtr prompt_flash_attention_cnode = nullptr;
1338   if (atten_mask != nullptr) {
1339     prompt_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v, atten_mask});
1340   } else {
1341     prompt_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v});
1342   }
1343   if (prompt_flash_attention_cnode == nullptr) {
1344     MS_LOG(ERROR) << "new cnode failed.";
1345     return nullptr;
1346   }
1347   prompt_flash_attention_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_prompt_flash_attention_bnsd");
1348   if (node->abstract() != nullptr) {
1349     prompt_flash_attention_cnode->set_abstract(node->abstract()->Clone());
1350   }
1351   MS_LOG(INFO) << "create PromptFlashAttention success.";
1352   return prompt_flash_attention_cnode;
1353 }
1354 
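// Pse variant of the BNSD builder: the two None value nodes are placeholders for the optional inputs that
// sit between atten_mask and pse in the primitive's input list, so pse lands in the expected slot.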
1355 CNodePtr FlashAttentionFusion::CreatePromptFlashAttentionCnodeForBNSDWithPse(
1356   const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
1357   const AnfNodePtr &atten_mask, const AnfNodePtr &pse, int64_t num_heads, int64_t next_token, float scale_value,
1358   const std::shared_ptr<FlashAttentionParm> &fa_parm, int64_t num_key_value_heads) const {
1359   MS_LOG(INFO) << "CreatePromptFlashAttentionCnodeForBNSD with pse";
1360   MS_LOG(INFO) << "num heads: " << num_heads << ", input layout: BNSD, next tokens: " << next_token
1361                << ", scale value: " << scale_value << ", num_key_value_heads: " << num_key_value_heads;
1362   MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << ", k name: " << k->fullname_with_scope()
1363                << ", v name: " << v->fullname_with_scope() << ", pse name " << pse->fullname_with_scope();
1364   if (num_heads < 0 || scale_value < 0 || next_token < 0 || num_key_value_heads < 0) {
1365     MS_LOG(WARNING) << "flash attention param is invalid";
1366     return nullptr;
1367   }
1368   if (fa_parm == nullptr) {
1369     MS_LOG(WARNING) << "FA parameter is null, please check";
1370     return nullptr;
1371   }
1372   // create op
1373   auto prompt_flash_attention_prim = std::make_shared<ops::PromptFlashAttention>();
1374   if (prompt_flash_attention_prim == nullptr) {
1375     MS_LOG(ERROR) << "new prompt flash attention prim failed.";
1376     return nullptr;
1377   }
1378   // add attr
1379   prompt_flash_attention_prim->AddAttr("num_heads", api::MakeValue(num_heads));
1380   prompt_flash_attention_prim->AddAttr("input_layout", api::MakeValue("BNSD"));
1381   prompt_flash_attention_prim->AddAttr("next_tokens", api::MakeValue(next_token));
1382   prompt_flash_attention_prim->AddAttr("scale_value", api::MakeValue(scale_value));
1383   prompt_flash_attention_prim->AddAttr("num_key_value_heads", api::MakeValue(num_key_value_heads));
1384   prompt_flash_attention_prim->AddAttr("inner_precise", api::MakeValue(fa_parm->inner_precise));
1385   prompt_flash_attention_prim->AddAttr("sparse_mode", api::MakeValue(fa_parm->sparse_mode));
1386 
1387   auto fa_prim_c = prompt_flash_attention_prim->GetPrim();
1388   if (fa_prim_c == nullptr) {
1389     MS_LOG(ERROR) << "fa_prim_c is nullptr.";
1390     return nullptr;
1391   }
1392   CNodePtr prompt_flash_attention_cnode = nullptr;
1393   auto none_value_node_1 = NewValueNode(std::make_shared<None>());
1394   none_value_node_1->set_abstract(std::make_shared<abstract::AbstractNone>());
1395   auto none_value_node_2 = NewValueNode(std::make_shared<None>());
1396   none_value_node_2->set_abstract(std::make_shared<abstract::AbstractNone>());
1397 
1398   if (atten_mask != nullptr) {
1399     prompt_flash_attention_cnode =
1400       func_graph->NewCNode(fa_prim_c, {q, k, v, atten_mask, none_value_node_1, none_value_node_2, pse});
1401   } else {
1402     auto none_value_node = NewValueNode(std::make_shared<None>());
1403     none_value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
1404     prompt_flash_attention_cnode =
1405       func_graph->NewCNode(fa_prim_c, {q, k, v, none_value_node, none_value_node_1, none_value_node_2, pse});
1406   }
1407   if (prompt_flash_attention_cnode == nullptr) {
1408     MS_LOG(ERROR) << "new prompt_flash_attention_cnode failed.";
1409     return nullptr;
1410   }
1411   prompt_flash_attention_cnode->set_fullname_with_scope(node->fullname_with_scope() +
1412                                                         "_prompt_flash_attention_bnsd_pse");
1413   if (node->abstract() != nullptr) {
1414     prompt_flash_attention_cnode->set_abstract(node->abstract()->Clone());
1415   }
1416   MS_LOG(INFO) << "create PromptFlashAttention success.";
1417   return prompt_flash_attention_cnode;
1418 }
1419 
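// BSH-layout variant of the PromptFlashAttention builder; num_key_value_heads is simply set to num_heads.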
1420 CNodePtr FlashAttentionFusion::CreatePromptFlashAttentionCnodeForBSH(
1421   const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
1422   const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token, float scale_value,
1423   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1424   MS_LOG(INFO) << "CreatePromptFlashAttentionCnodeForBSH";
1425   MS_LOG(INFO) << "input Q name: " << q->fullname_with_scope() << " ,input K name: " << k->fullname_with_scope()
1426                << " ,input V name: " << v->fullname_with_scope();
1427   // create op
1428   auto prompt_flash_attention_prim = std::make_shared<ops::PromptFlashAttention>();
1429   if (prompt_flash_attention_prim == nullptr) {
1430     MS_LOG(ERROR) << "prompt_flash_attention_prim is nullptr.";
1431     return nullptr;
1432   }
1433   if (fa_parm == nullptr) {
1434     MS_LOG(WARNING) << "FA parameter is null, please check";
1435     return nullptr;
1436   }
1437   // add attr
1438   prompt_flash_attention_prim->AddAttr("num_heads", api::MakeValue(num_heads));
1439   prompt_flash_attention_prim->AddAttr("input_layout", api::MakeValue("BSH"));
1440   prompt_flash_attention_prim->AddAttr("next_tokens", api::MakeValue(next_token));
1441   prompt_flash_attention_prim->AddAttr("scale_value", api::MakeValue(scale_value));
1442   prompt_flash_attention_prim->AddAttr("num_key_value_heads", api::MakeValue(num_heads));
1443   prompt_flash_attention_prim->AddAttr("inner_precise", api::MakeValue(fa_parm->inner_precise));
1444   prompt_flash_attention_prim->AddAttr("sparse_mode", api::MakeValue(fa_parm->sparse_mode));
1445 
1446   MS_LOG(INFO) << "num heads: " << num_heads << ", input layout: BSH, next tokens: " << next_token
1447                << ", scale value: " << scale_value;
1448   auto fa_prim_c = prompt_flash_attention_prim->GetPrim();
1449   CNodePtr prompt_flash_attention_cnode = nullptr;
1450   if (atten_mask != nullptr) {
1451     prompt_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v, atten_mask});
1452   } else {
1453     prompt_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v});
1454   }
1455   if (prompt_flash_attention_cnode == nullptr) {
1456     MS_LOG(ERROR) << "new cnode failed.";
1457     return nullptr;
1458   }
1459   prompt_flash_attention_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_prompt_flash_attention_bsh");
1460   if (node->abstract() != nullptr) {
1461     prompt_flash_attention_cnode->set_abstract(node->abstract()->Clone());
1462   }
1463   MS_LOG(INFO) << "create PromptFlashAttention success.";
1464   return prompt_flash_attention_cnode;
1465 }
1466 
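// Rewrites a matched BNSD attention block (with attention mask) into flash attention: scale_value is taken
// as 1/sqrt(D) via pow(D, 0.5), and PromptFlashAttention is used unless the query sequence length S is 1,
// in which case IncreFlashAttention is created instead.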
1467 CNodePtr FlashAttentionFusion::CreateFAForBNSDWithAttenMask(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
1468                                                             const CNodePtr &qk_matmul, const CNodePtr &v_matmul,
1469                                                             const CNodePtr &attention_mask_mul,
1470                                                             const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1471   auto q = qk_matmul->input(kNumIndex1);
1472   MS_CHECK_TRUE_RET(q != nullptr, nullptr);
1473   auto k = qk_matmul->input(kNumIndex2);
1474   MS_CHECK_TRUE_RET(k != nullptr, nullptr);
1475   auto v = v_matmul->input(kNumIndex2);
1476   MS_CHECK_TRUE_RET(v != nullptr, nullptr);
1477   auto atten_mask = attention_mask_mul->input(1)->cast<CNodePtr>();
1478   MS_CHECK_TRUE_RET(atten_mask != nullptr, nullptr);
1479 
1480   auto input_tensor_q_shape = GetTensorShape(qk_matmul, kNumIndex1);
1481   if (input_tensor_q_shape.size() != kNumDimSize4) {
1482     MS_LOG(ERROR) << "q shape is not 4 dims";
1483     return nullptr;
1484   }
1485   auto input_tensor_k_shape = GetTensorShape(qk_matmul, kNumIndex2);
1486   if (input_tensor_k_shape.size() != kNumDimSize4) {
1487     MS_LOG(ERROR) << "k shape is not 4 dims";
1488     return nullptr;
1489   }
1490   auto input_tensor_v_shape = GetTensorShape(v_matmul, kNumIndex2);
1491   if (input_tensor_v_shape.size() != kNumDimSize4) {
1492     MS_LOG(ERROR) << "v shape is not 4 dims";
1493     return nullptr;
1494   }
1495   auto atten_mask_input_shape = GetTensorShape(attention_mask_mul, 1);
1496   if (atten_mask_input_shape.size() != kNumDimSize4) {
1497     MS_LOG(ERROR) << "atten mask shape is not 4 dims";
1498     return nullptr;
1499   }
1500   MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << " , k name: " << k->fullname_with_scope()
1501                << " , v name: " << v->fullname_with_scope()
1502                << ", atten mask name: " << atten_mask->fullname_with_scope();
1503   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << ", k shape: " << input_tensor_k_shape
1504                << ", v shape: " << input_tensor_v_shape << ", atten mask shape: " << atten_mask_input_shape;
1505   // check input shape
1506   if (input_tensor_q_shape[kNumIndex3] <= 0 || input_tensor_q_shape[kNumIndex1] <= 0) {
1507     MS_LOG(ERROR) << "N or D of q shape is invalid";
1508     return nullptr;
1509   }
1510   float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1511   int64_t seq_len = input_tensor_q_shape[kNumIndex2];
1512   int64_t num_key_value_heads = input_tensor_k_shape[1];
1513   if (seq_len != 1) {
1514     return CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
1515                                                   input_tensor_q_shape[kNumIndex1], 0, scale_value, fa_parm,
1516                                                   num_key_value_heads);
1517   } else {
1518     MS_LOG(INFO) << "seq len is 1, incre flash attention.";
1519     return CreateIncreFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
1520                                                  input_tensor_q_shape[kNumIndex1], scale_value, num_key_value_heads);
1521   }
1522   return nullptr;
1523 }
1524 
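// GQA case: per the node naming, K and V reach the matmuls through ExpandDims -> Tile -> Reshape, so this
// walks back through those nodes to feed the original (un-tiled) K/V into flash attention and passes their
// head count as num_key_value_heads.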
1525 CNodePtr FlashAttentionFusion::CreateGQACNodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
1526                                                      const CNodePtr &qk_matmul, const CNodePtr &v_matmul,
1527                                                      const CNodePtr &attention_mask_mul,
1528                                                      const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1529   auto q = qk_matmul->input(kNumIndex1);
1530   MS_CHECK_TRUE_RET(q != nullptr, nullptr);
1531 
1532   auto k_reshape = qk_matmul->input(kNumIndex2)->cast<CNodePtr>();
1533   MS_LOG(INFO) << k_reshape->fullname_with_scope();
1534   auto k_tile = k_reshape->input(kNumIndex1)->cast<CNodePtr>();
1535   MS_LOG(INFO) << k_tile->fullname_with_scope();
1536   auto k_expend_dim = k_tile->input(kNumIndex1)->cast<CNodePtr>();
1537 
1538   auto v_reshape = v_matmul->input(kNumIndex2)->cast<CNodePtr>();
1539   MS_LOG(INFO) << v_reshape->fullname_with_scope();
1540   auto v_tile = v_reshape->input(kNumIndex1)->cast<CNodePtr>();
1541   MS_LOG(INFO) << v_tile->fullname_with_scope();
1542   auto v_expend_dim = v_tile->input(kNumIndex1)->cast<CNodePtr>();
1543 
1544   auto k = k_expend_dim->input(kNumIndex1);
1545   MS_CHECK_TRUE_RET(k != nullptr, nullptr);
1546   auto v = v_expend_dim->input(kNumIndex1);
1547   MS_CHECK_TRUE_RET(v != nullptr, nullptr);
1548 
1549   auto atten_mask = attention_mask_mul->input(kNumIndex1)->cast<CNodePtr>();
1550   MS_CHECK_TRUE_RET(atten_mask != nullptr, nullptr);
1551 
1552   auto input_tensor_q_shape = GetTensorShape(qk_matmul, kNumIndex1);
1553   if (input_tensor_q_shape.size() != kNumDimSize4) {
1554     MS_LOG(ERROR) << "q shape is not 4 dims";
1555     return nullptr;
1556   }
1557   auto input_tensor_k_shape = GetTensorShape(k_expend_dim, kNumIndex1);
1558   if (input_tensor_k_shape.size() != kNumDimSize4) {
1559     MS_LOG(ERROR) << "k shape is not 4 dims";
1560     return nullptr;
1561   }
1562   auto input_tensor_v_shape = GetTensorShape(v_expend_dim, kNumIndex1);
1563   if (input_tensor_v_shape.size() != kNumDimSize4) {
1564     MS_LOG(ERROR) << "v shape is not 4 dims";
1565     return nullptr;
1566   }
1567   auto atten_mask_input_shape = GetTensorShape(attention_mask_mul, 1);
1568   if (atten_mask_input_shape.size() != kNumDimSize4) {
1569     MS_LOG(ERROR) << "atten mask shape is not 4 dims";
1570     return nullptr;
1571   }
1572   MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << " , k name: " << k->fullname_with_scope()
1573                << " , v name: " << v->fullname_with_scope()
1574                << ", atten mask name: " << atten_mask->fullname_with_scope();
1575   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << ", k shape: " << input_tensor_k_shape
1576                << ", v shape: " << input_tensor_v_shape << ", atten mask shape: " << atten_mask_input_shape;
1577   // check input shape
1578   if (input_tensor_q_shape[kNumIndex3] <= 0 || input_tensor_q_shape[kNumIndex1] <= 0) {
1579     MS_LOG(ERROR) << "N or D of q shape is invalid";
1580     return nullptr;
1581   }
1582   float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1583   int64_t seq_len = input_tensor_q_shape[kNumIndex2];
1584   int64_t num_key_value_heads = input_tensor_k_shape[1];
1585   if (seq_len != 1) {
1586     return CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
1587                                                   input_tensor_q_shape[kNumIndex1], 0, scale_value, fa_parm,
1588                                                   num_key_value_heads);
1589   } else {
1590     MS_LOG(INFO) << "seq len is 1, incre flash attention.";
1591     return CreateIncreFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
1592                                                  input_tensor_q_shape[kNumIndex1], scale_value, num_key_value_heads);
1593   }
1594   return nullptr;
1595 }
1596 
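// Builds an IncreFlashAttention cnode (BNSD) for the S == 1 decode step; besides the usual attrs it attaches
// a dyn_input_sizes attr describing the primitive's input grouping (semantics assumed from the attr name).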
1597 CNodePtr FlashAttentionFusion::CreateIncreFlashAttentionCnodeForBNSD(
1598   const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
1599   const AnfNodePtr &atten_mask, int64_t num_heads, float scale_value, int64_t num_key_value_heads) const {
1600   MS_LOG(INFO) << "CreateIncreFlashAttentionCnodeForBNSD";
1601   // create op
1602   auto incre_flash_attention_prim = std::make_shared<ops::IncreFlashAttention>();
1603   if (incre_flash_attention_prim == nullptr) {
1604     MS_LOG(ERROR) << "incre_flash_attention_prim is nullptr.";
1605     return nullptr;
1606   }
1607   // add attr
1608   incre_flash_attention_prim->AddAttr("num_heads", api::MakeValue(num_heads));
1609   incre_flash_attention_prim->AddAttr("input_layout", api::MakeValue("BNSD"));
1610   incre_flash_attention_prim->AddAttr("scale_value", api::MakeValue(scale_value));
1611   incre_flash_attention_prim->AddAttr("num_key_value_heads", api::MakeValue(num_key_value_heads));
1612 
1613   std::vector<int64_t> dyn_input_sizes = {-1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1};
1614   incre_flash_attention_prim->AddAttr("dyn_input_sizes", api::MakeValue(dyn_input_sizes));
1615 
1616   MS_LOG(INFO) << "num heads: " << num_heads << ", input layout: BNSD, scale value: " << scale_value
1617                << ", num_key_value_heads: " << num_key_value_heads << ", dyn_input_sizes:" << dyn_input_sizes;
1618   auto fa_prim_c = incre_flash_attention_prim->GetPrim();
1619   CNodePtr incre_flash_attention_cnode = nullptr;
1620   if (atten_mask != nullptr) {
1621     incre_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v, atten_mask});
1622   } else {
1623     incre_flash_attention_cnode = func_graph->NewCNode(fa_prim_c, {q, k, v});
1624   }
1625   if (incre_flash_attention_cnode == nullptr) {
1626     MS_LOG(ERROR) << "new cnode failed.";
1627     return nullptr;
1628   }
1629   incre_flash_attention_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_incre_flash_attention");
1630   if (node->abstract() != nullptr) {
1631     incre_flash_attention_cnode->set_abstract(node->abstract()->Clone());
1632   }
1633   MS_LOG(INFO) << "create IncreFlashAttention success.";
1634   return incre_flash_attention_cnode;
1635 }
1636 
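// SD 1.5 path for head dim D == kNumDValue (40): Q/K/V are padded along D by kNumPadSize before flash
// attention and the FA output is sliced back to D = 40 afterwards; presumably this aligns D for the fused
// kernel, though the code only shows the pad -> FA -> slice flow.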
1637 CNodePtr FlashAttentionFusion::CreateFAForSD15(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
1638                                                const AnfNodePtr &q_trans, const AnfNodePtr &k_trans,
1639                                                const AnfNodePtr &v_trans, int64_t num_head, int64_t next_token,
1640                                                float scale_value,
1641                                                const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1642   MS_LOG(INFO) << "create flash attention for stable diffusion V1.5.";
1643   auto q_pad_node = CreatePadCNode(func_graph, q_trans, kNumPadSize, node->fullname_with_scope());
1644   if (q_pad_node == nullptr) {
1645     MS_LOG(WARNING) << "create q_pad_node failed.";
1646     return nullptr;
1647   }
1648   auto k_pad_node = CreatePadCNode(func_graph, k_trans, kNumPadSize, node->fullname_with_scope());
1649   if (k_pad_node == nullptr) {
1650     MS_LOG(WARNING) << "create k_pad_node failed.";
1651     return nullptr;
1652   }
1653   auto v_pad_node = CreatePadCNode(func_graph, v_trans, kNumPadSize, node->fullname_with_scope());
1654   if (v_pad_node == nullptr) {
1655     MS_LOG(WARNING) << "create v_pad_node failed.";
1656     return nullptr;
1657   }
1658   auto fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_pad_node, k_pad_node, v_pad_node, nullptr,
1659                                                         num_head, next_token, scale_value, fa_parm, num_head);
1660   if (fa_node == nullptr) {
1661     MS_LOG(WARNING) << "create fa_node failed.";
1662     return nullptr;
1663   }
1664   auto slice_node = CreateSliceCNode(func_graph, fa_node, kNumDValue);
1665   return slice_node;
1666 }
1667 
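// Same pad -> FA -> slice flow as CreateFAForSD15, but the FA node is built through the pse-aware BNSD
// builder so an additive bias can be passed in.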
1668 CNodePtr FlashAttentionFusion::CreateFAWithPadAndPse(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
1669                                                      const AnfNodePtr &q_trans, const AnfNodePtr &k_trans,
1670                                                      const AnfNodePtr &v_trans, const AnfNodePtr &pse, int64_t num_head,
1671                                                      int64_t next_token, float scale_value,
1672                                                      const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1673   MS_LOG(INFO) << "create flash attention for stable diffusion with pse input.";
1674   auto q_pad_node = CreatePadCNode(func_graph, q_trans, kNumPadSize);
1675   if (q_pad_node == nullptr) {
1676     MS_LOG(WARNING) << "create q_pad_node failed.";
1677     return nullptr;
1678   }
1679   auto k_pad_node = CreatePadCNode(func_graph, k_trans, kNumPadSize);
1680   if (k_pad_node == nullptr) {
1681     MS_LOG(WARNING) << "create k_pad_node failed.";
1682     return nullptr;
1683   }
1684   auto v_pad_node = CreatePadCNode(func_graph, v_trans, kNumPadSize);
1685   if (v_pad_node == nullptr) {
1686     MS_LOG(WARNING) << "create v_pad_node failed.";
1687     return nullptr;
1688   }
1689   auto fa_node =
1690     CreatePromptFlashAttentionCnodeForBNSDWithPse(func_graph, node, q_pad_node, k_pad_node, v_pad_node, nullptr, pse,
1691                                                   num_head, next_token, scale_value, fa_parm, num_head);
1692   if (fa_node == nullptr) {
1693     MS_LOG(WARNING) << "create fa_node failed.";
1694     return nullptr;
1695   }
1696 
1697   auto slice_node = CreateSliceCNode(func_graph, fa_node, kNumDValue);
1698   return slice_node;
1699 }
1700 
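// Reads the scalar softmax scale out of the Mul's constant operand, accepting either a ValueNode or a
// Parameter holding a single float32/float16 element; returns -1 on any failure.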
1701 float FlashAttentionFusion::GetScaleValueForDynamicShape(const AnfNodePtr &mul_const_input) const {
1702   tensor::TensorPtr tensor_info = nullptr;
1703   if (utils::isa<ValueNodePtr>(mul_const_input)) {
1704     auto value_node = mul_const_input->cast<ValueNodePtr>();
1705     if (value_node == nullptr) {
1706       MS_LOG(WARNING) << "value_node is nullptr.";
1707       return -1;
1708     }
1709     auto value = value_node->value();
1710     if (value == nullptr) {
1711       MS_LOG(WARNING) << "value is nullptr.";
1712       return -1;
1713     }
1714     tensor_info = value->cast<tensor::TensorPtr>();
1715   } else if (utils::isa<ParameterPtr>(mul_const_input)) {
1716     // for dynamic shape: get scale value
1717     auto mul_param = mul_const_input->cast<ParameterPtr>()->default_param();
1718     if (mul_param == nullptr) {
1719       MS_LOG(WARNING) << "mul_param is nullptr.";
1720       return -1;
1721     }
1722     tensor_info = mul_param->cast<tensor::TensorPtr>();
1723   } else {
1724     MS_LOG(WARNING) << "mul input is not ParameterPtr or ValueNodePtr.";
1725     return -1;
1726   }
1727   if (tensor_info == nullptr) {
1728     MS_LOG(WARNING) << "tensor info is nullptr.";
1729     return -1;
1730   }
1731   if (tensor_info->data_c() == nullptr) {
1732     MS_LOG(WARNING) << "mul data is nullptr.";
1733     return -1;
1734   }
1735   if (tensor_info->ElementsNum() != 1) {
1736     MS_LOG(WARNING) << "mul value elements num is not 1, ElementsNum is: " << tensor_info->ElementsNum();
1737     return -1;
1738   }
1739   if (tensor_info->data_type() == kNumberTypeFloat32) {
1740     return static_cast<float *>(tensor_info->data_c())[0];
1741   } else if (tensor_info->data_type() == kNumberTypeFloat16) {
1742     return static_cast<float>(static_cast<float16 *>(tensor_info->data_c())[0]);
1743   } else {
1744     MS_LOG(ERROR) << "data type is not supported, " << tensor_info->data_type();
1745     return -1;
1746   }
1747   return -1;
1748 }
1749 
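// SDXL rewrite: walks matmul_2 -> Softmax -> Div -> matmul_1 to recover Q, K (before its Transpose) and V,
// derives num_heads and scale_value (1/sqrt(D)) from the static BNSD q shape, and replaces matmul_2 with a
// PromptFlashAttention node.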
1750 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForMsSDXL(
1751   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
1752   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1753   MS_LOG(INFO) << "flash attention for SDXL";
1754   MS_CHECK_TRUE_RET(node != nullptr, nullptr);
1755   auto matmul_2 = node->cast<CNodePtr>();
1756   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
1757   auto softmax = matmul_2->input(1)->cast<CNodePtr>();
1758   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
1759   auto div = softmax->input(1)->cast<CNodePtr>();
1760   MS_CHECK_TRUE_RET(div != nullptr, nullptr);
1761   auto matmul_1 = div->input(1)->cast<CNodePtr>();
1762   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
1763 
1764   auto q = matmul_1->input(1)->cast<CNodePtr>();
1765   MS_CHECK_TRUE_RET(q != nullptr, nullptr);
1766   auto k_trans = matmul_1->input(2)->cast<CNodePtr>();
1767   MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
1768   auto k = k_trans->input(1)->cast<CNodePtr>();
1769   MS_CHECK_TRUE_RET(k != nullptr, nullptr);
1770   auto v = matmul_2->input(2)->cast<CNodePtr>();
1771   MS_CHECK_TRUE_RET(v != nullptr, nullptr);
1772 
1773   auto input_tensor_q_shape = GetTensorShape(matmul_1, kNumIndex1);
1774   auto input_tensor_k_shape = GetTensorShape(k_trans, kNumIndex1);
1775   auto input_tensor_v_shape = GetTensorShape(matmul_2, kNumIndex2);
1776 
1777   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
1778                << " , v shape: " << input_tensor_v_shape;
1779   if (input_tensor_q_shape.size() != kNumShapeSize4 || input_tensor_k_shape.size() != kNumShapeSize4 ||
1780       input_tensor_v_shape.size() != kNumShapeSize4) {
1781     MS_LOG(WARNING) << "input shape is not 4 dims";
1782     return nullptr;
1783   }
1784 
1785   int64_t next_tokens = kNumMaxNextTokenSize;
1786   float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1787   int64_t num_head = input_tensor_q_shape[kNumIndex1];
1788 
1789   if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
1790                      fa_parm->seq_threshold)) {
1791     MS_LOG(INFO) << "shape check failed.";
1792     return nullptr;
1793   }
1794   auto fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, nullptr, num_head, next_tokens,
1795                                                         scale_value, fa_parm, num_head);
1796   MS_CHECK_TRUE_MSG(fa_node != nullptr, nullptr, "create FA failed, fa_node is nullptr.");
1797   auto manager = Manage(func_graph);
1798   (void)manager->Replace(matmul_2, fa_node);
1799   MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
1800   return nullptr;
1801 }
1802 
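// Pse-shift rewrite: after recovering Q/K/V, the additive bias feeding the Add node is reshaped from
// (B*N, S, S) to (B, N, S, S) so it can be passed to PromptFlashAttention as pse; D == 40 goes through the
// pad-and-slice path, everything else through the plain pse-aware BNSD builder.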
1803 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForMsSDPseShift(
1804   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
1805   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1806   MS_LOG(INFO) << "flash attention for SD pse shift";
1807   MS_CHECK_TRUE_RET(node != nullptr, nullptr);
1808   // reshape
1809   auto reshape_output = node->cast<CNodePtr>();
1810   MS_CHECK_TRUE_RET(reshape_output != nullptr, nullptr);
1811   // matmul
1812   auto matmul_2 = reshape_output->input(kNumIndex1)->cast<CNodePtr>();
1813   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
1814   // cast
1815   auto cast = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
1816   MS_CHECK_TRUE_RET(cast != nullptr, nullptr);
1817   // softmax
1818   auto softmax = cast->input(kNumIndex1)->cast<CNodePtr>();
1819   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
1820   // add
1821   auto add = softmax->input(kNumIndex1)->cast<CNodePtr>();
1822   MS_CHECK_TRUE_RET(add != nullptr, nullptr);
1823   // mul
1824   auto mul = add->input(kNumIndex1)->cast<CNodePtr>();
1825   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
1826   // matmul
1827   auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
1828   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
1829 
1830   // Q
1831   auto q_reshape = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
1832   MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
1833   auto q_input = q_reshape->input(kNumIndex1)->cast<CNodePtr>();
1834   MS_CHECK_TRUE_RET(q_input != nullptr, nullptr);
1835 
1836   // K
1837   auto trans = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
1838   MS_CHECK_TRUE_RET(trans != nullptr, nullptr);
1839   auto k_reshape = trans->input(kNumIndex1)->cast<CNodePtr>();
1840   MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
1841   auto k_input = k_reshape->input(kNumIndex1)->cast<CNodePtr>();
1842   MS_CHECK_TRUE_RET(k_input != nullptr, nullptr);
1843 
1844   // V
1845   auto v_reshape = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
1846   MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
1847   auto v_input = v_reshape->input(kNumIndex1)->cast<CNodePtr>();
1848   MS_CHECK_TRUE_RET(v_input != nullptr, nullptr);
1849 
1850   auto input_tensor_q_shape = GetTensorShape(q_reshape, kNumIndex1);
1851   auto input_tensor_k_shape = GetTensorShape(k_reshape, kNumIndex1);
1852   auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
1853 
1854   if (input_tensor_q_shape.size() != kNumShapeSize4 || input_tensor_k_shape.size() != kNumShapeSize4 ||
1855       input_tensor_v_shape.size() != kNumShapeSize4) {
1856     MS_LOG(WARNING) << "Dynamic shape is not supported, and need to check Q/K/V input tensor shapes: "
1857                     << input_tensor_q_shape;
1858     return nullptr;
1859   }
1866 
1867   float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1868   int64_t num_head = input_tensor_q_shape[kNumIndex1];
1869   int64_t next_tokens = kNumMaxNextTokenSize;
1870   int64_t d_value = input_tensor_q_shape[kNumIndex3];
1871 
1872   // reshape add input2 from (B*N)SS to BNSS
1873   auto manager = Manage(func_graph);
1874   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
1875 
1876   std::vector<int32_t> BNSS_shape = {
1877     static_cast<int32_t>(input_tensor_q_shape[0]), static_cast<int32_t>(input_tensor_q_shape[1]),
1878     static_cast<int32_t>(input_tensor_q_shape[2]), static_cast<int32_t>(input_tensor_k_shape[2])};
1879   auto shape_parm_node = BuildIntVecParameterNode(func_graph, BNSS_shape, node->fullname_with_scope() + "_shape_perm");
1880   MS_CHECK_TRUE_MSG(shape_parm_node != nullptr, nullptr, "create shape_parm_node return nullptr");
1881 
1882   std::vector<AnfNodePtr> op_inputs;
1883   auto add_input2 = add->input(kNumIndex2);
1884   if (utils::isa<CNodePtr>(add_input2)) {
1885     auto pse_BN_S_S_node = add_input2->cast<CNodePtr>();
1886     op_inputs = {pse_BN_S_S_node, shape_parm_node};
1887   } else if (utils::isa<ParameterPtr>(add_input2)) {
1888     auto pse_BN_S_S_parm = add_input2->cast<ParameterPtr>();
1889     op_inputs = {pse_BN_S_S_parm, shape_parm_node};
1890   } else {
1891     return nullptr;
1892   }
1893 
1894   auto reshape_prim = std::make_shared<ops::Reshape>();
1895   MS_CHECK_TRUE_MSG(reshape_prim != nullptr, nullptr, "create reshape_prim return nullptr");
1896   auto reshape_prim_c = reshape_prim->GetPrim();
1897   MS_CHECK_TRUE_MSG(reshape_prim_c != nullptr, nullptr, "create prim_c return nullptr");
1898   auto pse_BNSS_node = func_graph->NewCNode(reshape_prim_c, op_inputs);
1899   MS_CHECK_TRUE_MSG(pse_BNSS_node != nullptr, nullptr, "create pse_BNSS_node return nullptr");
1900   pse_BNSS_node->set_fullname_with_scope(node->fullname_with_scope() + "_pse_reshape");
1901   if (node->abstract() != nullptr) {
1902     pse_BNSS_node->set_abstract(node->abstract()->Clone());
1903   }
1904 
1905   // FA fusion
1906   CNodePtr fa_node = nullptr;
1907   if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
1908                      fa_parm->seq_threshold)) {
1909     return nullptr;
1910   }
1911   if (d_value == kNumDValue) {
1912     fa_node = CreateFAWithPadAndPse(func_graph, node, q_input, k_input, v_input, pse_BNSS_node, num_head, next_tokens,
1913                                     scale_value, fa_parm);
1914   } else {
1915     fa_node =
1916       CreatePromptFlashAttentionCnodeForBNSDWithPse(func_graph, node, q_input, k_input, v_input, nullptr, pse_BNSS_node,
1917                                                     num_head, next_tokens, scale_value, fa_parm, num_head);
1918   }
1919   MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
1920   (void)manager->Replace(node, fa_node);
1921   MS_LOG(INFO) << "create prompt flash attention success for stable diffusion pse shift.";
1922   return nullptr;
1923 }
1924 
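// SD 2.1 rewrite: Q/K/V are the Transpose nodes feeding the Reshapes in front of the two matmuls; the
// matched subgraph is replaced with a static-shape BNSD PromptFlashAttention node.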
1925 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForMsSD21(
1926   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
1927   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1928   MS_LOG(INFO) << "flash attention for SD21";
1929   MS_CHECK_TRUE_RET(node != nullptr, nullptr);
1930   auto matmul_2 = node->cast<CNodePtr>();
1931   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
1932   auto softmax = matmul_2->input(1)->cast<CNodePtr>();
1933   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
1934   auto mul = softmax->input(1)->cast<CNodePtr>();
1935   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
1936   auto matmul_1 = mul->input(1)->cast<CNodePtr>();
1937   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
1938   auto transpose = matmul_1->input(2)->cast<CNodePtr>();
1939   MS_CHECK_TRUE_RET(transpose != nullptr, nullptr);
1940 
1941   auto q_reshape = matmul_1->input(1)->cast<CNodePtr>();
1942   MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
1943   auto q_trans = q_reshape->input(1)->cast<CNodePtr>();
1944   MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
1945 
1946   auto k_reshape = transpose->input(1)->cast<CNodePtr>();
1947   MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
1948   auto k_trans = k_reshape->input(1)->cast<CNodePtr>();
1949   MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
1950 
1951   auto v_reshape = matmul_2->input(2)->cast<CNodePtr>();
1952   MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
1953   auto v_trans = v_reshape->input(1)->cast<CNodePtr>();
1954   MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
1955 
1956   auto input_tensor_q_shape = GetTensorShape(q_reshape, 1);
1957   auto input_tensor_k_shape = GetTensorShape(k_reshape, 1);
1958   auto input_tensor_v_shape = GetTensorShape(v_reshape, 1);
1959 
1960   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
1961                << " , v shape: " << input_tensor_v_shape;
1962   if (input_tensor_q_shape.size() != kNumShapeSize4 || input_tensor_k_shape.size() != kNumShapeSize4 ||
1963       input_tensor_v_shape.size() != kNumShapeSize4) {
1964     MS_LOG(WARNING) << "input shape is not 4 dims";
1965     return nullptr;
1966   }
1967 
1968   int64_t next_tokens = kNumMaxNextTokenSize;
1969   float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
1970   int64_t num_head = input_tensor_q_shape[kNumIndex1];
1971 
1972   if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
1973                      fa_parm->seq_threshold)) {
1974     MS_LOG(INFO) << "shape check failed.";
1975     return nullptr;
1976   }
1977 
1978   auto fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
1979                                                         next_tokens, scale_value, fa_parm, num_head);
1980   MS_CHECK_TRUE_MSG(fa_node != nullptr, nullptr, "create FA failed, fa_node is nullptr.");
1981   auto manager = Manage(func_graph);
1982   (void)manager->Replace(matmul_2, fa_node);
1983   MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
1984   return nullptr;
1985 }
1986 
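// VideoComposer rewrite: for dynamic q shapes the scale is read from the Mul constant and D is recovered as
// 1/scale^2, with the trailing output Reshape rewritten to {0, 0, -1}; static shapes use 1/sqrt(D) directly,
// and D == 40 again takes the pad-and-slice path.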
1987 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForVideoComposer(
1988   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
1989   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
1990   MS_LOG(INFO) << "flash attention for wanxin";
1991   MS_CHECK_TRUE_RET(node != nullptr, nullptr);
1992   auto reshape = node->cast<CNodePtr>();
1993   MS_CHECK_TRUE_RET(reshape != nullptr, nullptr);
1994   auto matmul_2 = reshape->input(1)->cast<CNodePtr>();
1995   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
1996   auto cast_2 = matmul_2->input(1)->cast<CNodePtr>();
1997   MS_CHECK_TRUE_RET(cast_2 != nullptr, nullptr);
1998   auto softmax = cast_2->input(1)->cast<CNodePtr>();
1999   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2000   auto cast_1 = softmax->input(1)->cast<CNodePtr>();
2001   MS_CHECK_TRUE_RET(cast_1 != nullptr, nullptr);
2002   auto mul = cast_1->input(1)->cast<CNodePtr>();
2003   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2004   auto matmul_1 = mul->input(1)->cast<CNodePtr>();
2005   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2006 
2007   auto q_reshape = matmul_1->input(1)->cast<CNodePtr>();
2008   MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2009   auto q_trans = q_reshape->input(1)->cast<CNodePtr>();
2010   MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
2011 
2012   auto k_trans_2 = matmul_1->input(2)->cast<CNodePtr>();
2013   MS_CHECK_TRUE_RET(k_trans_2 != nullptr, nullptr);
2014   auto k_reshape = k_trans_2->input(1)->cast<CNodePtr>();
2015   MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
2016   auto k_trans = k_reshape->input(1)->cast<CNodePtr>();
2017   MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
2018 
2019   auto v_reshape = matmul_2->input(2)->cast<CNodePtr>();
2020   MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
2021   auto v_trans = v_reshape->input(1)->cast<CNodePtr>();
2022   MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
2023 
2024   auto input_tensor_q_shape = GetTensorShape(q_reshape, 1);
2025   auto input_tensor_k_shape = GetTensorShape(k_reshape, 1);
2026   auto input_tensor_v_shape = GetTensorShape(v_reshape, 1);
2027 
2028   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2029                << " , v shape: " << input_tensor_v_shape;
2030 
2031   float scale_value = 0;
2032   int64_t num_head = 0;
2033   int64_t next_tokens = kNumMaxNextTokenSize;
2034   int64_t d_value = 0;
2035   auto mul_const_input = mul->input(kNumIndex2);
2036 
2037   if (input_tensor_q_shape.size() != kNumShapeSize4) {
2038     scale_value = GetScaleValueForDynamicShape(mul_const_input);
2039     d_value = 1 / pow(scale_value, kNumPowerTwo);
2040     MS_LOG(INFO) << "d_value: " << d_value;
2041     // process bnsd shape
2042     MS_LOG(INFO) << "get flash attention param for dynamic shape, scale value is " << scale_value;
2043     std::vector<int32_t> new_shape = {0, 0, -1};
2044     auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
2045     auto output_shape_node = node->cast<CNodePtr>();
2046     output_shape_node->set_input(kNumIndex2, shape_node);
2047     auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
2048     num_head = GetNumHeadForSD(q_trans_reshape);
2049   } else if (input_tensor_q_shape.size() == kNumShapeSize4) {
2050     MS_LOG(INFO) << "get flash attention param for static shape.";
2051     // for static shape: get scale value
2052     scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2053     num_head = input_tensor_q_shape[kNumIndex1];
2054     d_value = input_tensor_q_shape[kNumIndex3];
2055   } else {
2056     MS_LOG(WARNING) << "need check Q input tensor shape: " << input_tensor_q_shape;
2057     return nullptr;
2058   }
2059   CNodePtr fa_node = nullptr;
2060   if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2061                      fa_parm->seq_threshold)) {
2062     return nullptr;
2063   }
2064   if (d_value == kNumDValue) {
2065     fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value, fa_parm);
2066   } else {
2067     fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
2068                                                      next_tokens, scale_value, fa_parm, num_head);
2069   }
2070   if (fa_node == nullptr) {
2071     return nullptr;
2072   }
2073   auto manager = Manage(func_graph);
       MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2074   (void)manager->Replace(matmul_2, fa_node);
2075   MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
2076   return nullptr;
2077 }
2078 
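     // Matches the stable-diffusion attention subgraph rooted at the output Reshape:
     // Reshape(Transpose(Reshape(MatMul(Cast(Softmax(Mul(MatMul(Q, K^T), scale))), V)))).
     // When the shape checks pass it builds a PromptFlashAttention (or the SD 1.5 variant for
     // head_dim 40) and rewires the graph itself via the manager, so it intentionally returns
     // nullptr to the caller.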
2079 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSD(const std::string &pattern_name,
2080                                                              const FuncGraphPtr &func_graph, const AnfNodePtr &node,
2081                                                              const EquivPtr &equiv,
2082                                                              const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2083   auto cnode = node->cast<CNodePtr>();
2084   auto reshape_o2 = cnode;
2085   MS_CHECK_TRUE_RET(reshape_o2 != nullptr, nullptr);
2086   auto output_trans = reshape_o2->input(kNumIndex1)->cast<CNodePtr>();
2087   MS_CHECK_TRUE_RET(output_trans != nullptr, nullptr);
2088   cnode = output_trans->input(kNumIndex1)->cast<CNodePtr>();  // reshape
2089   MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
2090   auto matmul_2 = cnode->input(1)->cast<CNodePtr>();
2091   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2092   auto cast_2 = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2093   MS_CHECK_TRUE_RET(cast_2 != nullptr, nullptr);
2094   auto softmax = cast_2->input(kNumIndex1)->cast<CNodePtr>();
2095   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2096   auto mul = softmax->input(kNumIndex1)->cast<CNodePtr>();
2097   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2098   auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2099   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2100   auto transpose = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
2101   MS_CHECK_TRUE_RET(transpose != nullptr, nullptr);
2102   auto q_reshape = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2103   MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2104   auto k_reshape = transpose->input(kNumIndex1)->cast<CNodePtr>();
2105   MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
2106   auto v_reshape = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
2107   MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
2108 
2109   auto q_trans = q_reshape->input(kNumIndex1);
2110   MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
2111   auto k_trans = k_reshape->input(kNumIndex1);
2112   MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
2113   auto v_trans = v_reshape->input(kNumIndex1);
2114   MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
2115 
2116   auto input_tensor_q_shape = GetTensorShape(q_reshape, kNumIndex1);
2117   auto input_tensor_k_shape = GetTensorShape(k_reshape, kNumIndex1);
2118   auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
2119   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2120                << " , v shape: " << input_tensor_v_shape;
2121 
2122   float scale_value = 0;
2123   int64_t num_head = 0;
2124   int64_t next_tokens = kNumMaxNextTokenSize;
2125   int64_t d_value = 0;
2126   auto mul_const_input = mul->input(kNumIndex2);
2127 
2128   if (input_tensor_q_shape.size() != kNumShapeSize4) {
2129     scale_value = GetScaleValueForDynamicShape(mul_const_input);
2130     d_value = 1 / pow(scale_value, kNumPowerTwo);
2131     // process bnsd shape
2132     MS_LOG(INFO) << "get flash attention param for dynamic shape, scale value is " << scale_value;
2133     std::vector<int32_t> new_shape = {0, 0, -1};
2134     auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
2135     auto output_shape_node = node->cast<CNodePtr>();
2136     output_shape_node->set_input(kNumIndex2, shape_node);
2137     auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
2138     num_head = GetNumHeadForSD(q_trans_reshape);
2139   } else if (input_tensor_q_shape.size() == kNumShapeSize4) {
2140     MS_LOG(INFO) << "get flash attention param for static shape.";
2141     // for static shape: get scale value
2142     scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2143     num_head = input_tensor_q_shape[kNumIndex1];
2144     d_value = input_tensor_q_shape[kNumIndex3];
2145   } else {
2146     MS_LOG(WARNING) << "unexpected Q input tensor shape: " << input_tensor_q_shape;
2147     return nullptr;
2148   }
2149   CNodePtr fa_node = nullptr;
2150   if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2151                      fa_parm->seq_threshold)) {
2152     return nullptr;
2153   }
2154   if (d_value == kNumDValue) {
2155     fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value, fa_parm);
2156   } else {
2157     fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
2158                                                      next_tokens, scale_value, fa_parm, num_head);
2159   }
2160   if (fa_node == nullptr) {
2161     return nullptr;
2162   }
2163   auto manager = Manage(func_graph);
       MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2164   (void)manager->Replace(cnode, fa_node);
2165   MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
2166   return nullptr;
2167 }
2168 
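     // Pre-mul variant of the SD pattern: Q and K are each scaled by a Mul before the first
     // MatMul, so there is no Mul between the MatMul and the Softmax. K reaches the pattern as
     // BNDS; after fusion the K Transpose's perm is rewritten to {0, 2, 1, 3}, presumably so the
     // fused node receives K in BNSD layout.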
2169 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDPreMul(
2170   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2171   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2172   auto output_reshape = node->cast<CNodePtr>();
2173   MS_CHECK_TRUE_RET(output_reshape != nullptr, nullptr);
2174   auto output_trans = output_reshape->input(kNumIndex1)->cast<CNodePtr>();
2175   MS_CHECK_TRUE_RET(output_trans != nullptr, nullptr);
2176   auto matmul_2 = output_trans->input(kNumIndex1)->cast<CNodePtr>();
2177   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2178   auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2179   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2180   auto matmul_1 = softmax->input(kNumIndex1)->cast<CNodePtr>();
2181   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2182   auto mul_q = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2183   MS_CHECK_TRUE_RET(mul_q != nullptr, nullptr);
2184   auto q_trans_BNSD = mul_q->input(kNumIndex1)->cast<CNodePtr>();
2185   MS_CHECK_TRUE_RET(q_trans_BNSD != nullptr, nullptr);
2186   auto mul_k = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
2187   MS_CHECK_TRUE_RET(mul_k != nullptr, nullptr);
2188   auto k_trans_BNDS = mul_k->input(kNumIndex1)->cast<CNodePtr>();
2189   MS_CHECK_TRUE_RET(k_trans_BNDS != nullptr, nullptr);
2190   auto v_trans_BNSD = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
2191   MS_CHECK_TRUE_RET(v_trans_BNSD != nullptr, nullptr);
2192   auto input_tensor_q_shape = GetTensorShape(mul_q, kNumIndex1);
2193   auto input_tensor_k_shape = GetTensorShape(mul_k, kNumIndex1);
2194   auto input_tensor_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2195   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2196                << " , v shape: " << input_tensor_v_shape << ". q name: " << q_trans_BNSD->fullname_with_scope()
2197                << ", k name: " << k_trans_BNDS->fullname_with_scope()
2198                << ", v name: " << v_trans_BNSD->fullname_with_scope();
2199 
2200   float scale_value = 0;
2201   int64_t num_head = 0;
2202   int64_t next_tokens = kNumMaxNextTokenSize;
2203   int64_t d_value = 0;
2204   if (input_tensor_q_shape.size() == 0 && input_tensor_k_shape.size() == 0 && input_tensor_v_shape.size() == 0) {
2205     MS_CHECK_TRUE_RET(q_trans_BNSD->inputs().size() > kNumIndex1, nullptr);
2206     auto q_reshape = q_trans_BNSD->input(kNumIndex1)->cast<CNodePtr>();
2207     MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2208     num_head = GetReshapeParam(q_reshape, kNumIndex2);
2209     d_value = GetReshapeParam(q_reshape, kNumIndex3);
2210     MS_LOG(INFO) << "num_head: " << num_head << ", d_value: " << d_value;
2211   } else if (input_tensor_q_shape.size() != kNumShapeSize4 || input_tensor_k_shape.size() != kNumShapeSize4) {
2212     auto pd2_q_conv = PD2DecoderPattern(q_trans_BNSD);
2213     if (IpAdapterPattern(q_trans_BNSD, k_trans_BNDS)) {
2214       if (!GetParamForIpAdapterPattern(q_trans_BNSD, k_trans_BNDS, &num_head, &d_value)) {
2215         MS_LOG(INFO) << "Get parameter for IpAdapterPattern failed";
2216         return nullptr;
2217       }
2218       std::vector<int32_t> new_shape = {0, 0, -1};
2219       auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
2220       MS_CHECK_TRUE_RET(shape_node != nullptr, nullptr);
2221       output_reshape->set_input(kNumIndex2, shape_node);
2222     } else if (pd2_q_conv != nullptr) {
2223       MS_LOG(INFO) << "Dynamic shape pattern: PD2DecoderPattern";
2224       auto pd2_q_conv_input2_shape = GetTensorShape(pd2_q_conv, kNumIndex2);
2225       d_value = pd2_q_conv_input2_shape[0];
2226       num_head = GetNumHeadForSD(q_trans_BNSD->input(kNumIndex1)->cast<CNodePtr>());
2227     } else {
2228       MS_LOG(INFO) << "Dynamic shape is not supported. Can not fusion FA.";
2229       return nullptr;
2230     }
2231   } else {
2232     MS_LOG(INFO) << "get flash attention param for static shape.";
2233     num_head = input_tensor_q_shape[kNumIndex1];
2234     d_value = input_tensor_q_shape[kNumIndex3];
2235     std::swap(input_tensor_k_shape[kNumIndex2], input_tensor_k_shape[kNumIndex3]);
2236   }
2237   scale_value = 1 / (pow(d_value, kNumPowerHalf));
2238   CNodePtr fa_node = nullptr;
2239   if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2240                      fa_parm->seq_threshold)) {
2241     return nullptr;
2242   }
2243   if (d_value == kNumDValue) {
2244     fa_node = CreateFAForSD15(func_graph, node, q_trans_BNSD, k_trans_BNDS, v_trans_BNSD, num_head, next_tokens,
2245                               scale_value, fa_parm);
2246   } else {
2247     fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans_BNSD, k_trans_BNDS, v_trans_BNSD,
2248                                                      nullptr, num_head, next_tokens, scale_value, fa_parm, num_head);
2249   }
2250   MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
2251   std::vector<int32_t> new_perm = {kNumIndex0, kNumIndex2, kNumIndex1, kNumIndex3};
2252   auto perm_node = BuildIntVecParameterNode(func_graph, new_perm, k_trans_BNDS->fullname_with_scope() + "_new_perm");
2253   MS_CHECK_TRUE_RET(perm_node != nullptr, nullptr);
2254   k_trans_BNDS->cast<CNodePtr>()->set_input(kNumIndex2, perm_node);
2255   auto manager = Manage(func_graph);
2256   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2257   (void)manager->Replace(matmul_2, fa_node);
2258   MS_LOG(INFO) << "create prompt flash attention success for pre-mul pattern.";
2259   return nullptr;
2260 }
2261 
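     // Same SD attention subgraph as CreateFlashAttentionNodeForSD but without the Cast nodes
     // around the Softmax. Depending on fa_parm->format_bsh, the fused PromptFlashAttention is
     // built in BSH layout directly from the Q/K/V projection MatMul outputs, or in BNSD layout
     // from the transposed Q/K/V tensors.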
2262 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDWithoutCast(
2263   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2264   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2265   MS_LOG(INFO) << "format_bsh: " << fa_parm->format_bsh << ", seq_threshold: " << fa_parm->seq_threshold
2266                << ", inner_precise: " << fa_parm->inner_precise;
2267   auto cnode = node->cast<CNodePtr>();
2268   auto reshape_o2 = cnode;
2269   MS_CHECK_TRUE_RET(reshape_o2 != nullptr, nullptr);
2270   auto output_trans = reshape_o2->input(kNumIndex1)->cast<CNodePtr>();
2271   MS_CHECK_TRUE_RET(output_trans != nullptr, nullptr);
2272   cnode = output_trans->input(kNumIndex1)->cast<CNodePtr>();  // reshape
2273   MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
2274   auto matmul_2 = cnode->input(1)->cast<CNodePtr>();
2275   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2276   auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2277   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2278   auto mul = softmax->input(kNumIndex1)->cast<CNodePtr>();
2279   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2280   auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2281   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2282   auto transpose = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
2283   MS_CHECK_TRUE_RET(transpose != nullptr, nullptr);
2284   auto q_reshape = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2285   MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2286   auto k_reshape = transpose->input(kNumIndex1)->cast<CNodePtr>();
2287   MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
2288   auto v_reshape = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
2289   MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
2290 
2291   auto q_trans = q_reshape->input(kNumIndex1);
2292   MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
2293   auto k_trans = k_reshape->input(kNumIndex1);
2294   MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
2295   auto v_trans = v_reshape->input(kNumIndex1);
2296   MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
2297 
2298   auto input_tensor_q_shape = GetTensorShape(q_reshape, kNumIndex1);
2299   auto input_tensor_k_shape = GetTensorShape(k_reshape, kNumIndex1);
2300   auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
2301   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2302                << " , v shape: " << input_tensor_v_shape;
2303   auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
2304   MS_CHECK_TRUE_RET(q_trans_reshape != nullptr, nullptr);
2305   auto k_trans_reshape = k_trans->cast<CNodePtr>()->input(kNumIndex1);
2306   MS_CHECK_TRUE_RET(k_trans_reshape != nullptr, nullptr);
2307   auto v_trans_reshape = v_trans->cast<CNodePtr>()->input(kNumIndex1);
2308   MS_CHECK_TRUE_RET(v_trans_reshape != nullptr, nullptr);
2309 
2310   auto top_matmul_q = q_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
2311   MS_CHECK_TRUE_RET(top_matmul_q != nullptr, nullptr);
2312   auto top_matmul_k = k_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
2313   MS_CHECK_TRUE_RET(top_matmul_k != nullptr, nullptr);
2314   auto top_matmul_v = v_trans_reshape->cast<CNodePtr>()->input(kNumIndex1);
2315   MS_CHECK_TRUE_RET(top_matmul_v != nullptr, nullptr);
2316 
2317   float scale_value = 0;
2318   int64_t num_head = 0;
2319   int64_t next_tokens = kNumMaxNextTokenSize;
2320   int64_t d_value = 0;
2321   auto mul_const_input = mul->input(kNumIndex2);
2322   bool actual_BSH = false;
2323 
2324   if (input_tensor_q_shape.size() != kNumShapeSize4) {
2325     scale_value = GetScaleValueForDynamicShape(mul_const_input);
2326     d_value = 1 / pow(scale_value, kNumPowerTwo);
2327     // process bnsd shape
2328     MS_LOG(INFO) << "get flash attention param for dynamic shape, scale value is " << scale_value;
2329     std::vector<int32_t> new_shape = {0, 0, -1};
2330     auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
2331     auto output_shape_node = node->cast<CNodePtr>();
2332     output_shape_node->set_input(kNumIndex2, shape_node);
2333     num_head = GetNumHeadForSD(q_trans_reshape);
2334   } else {
2335     MS_LOG(INFO) << "get flash attention param for static shape.";
2336     // for static shape: get scale value
2337     scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2338     num_head = input_tensor_q_shape[kNumIndex1];
2339     d_value = input_tensor_q_shape[kNumIndex3];
2340   }
2341   CNodePtr fa_node = nullptr;
2342   if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2343                      fa_parm->seq_threshold)) {
2344     MS_LOG(INFO) << "Can not pass shape check.";
2345     return nullptr;
2346   }
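       // head_dim == 40 (SD 1.5) goes through CreateFAForSD15, which is expected to pad D for
       // alignment before fusing. With format_bsh the fused node is fed directly from the Q/K/V
       // projection MatMul outputs (before reshape/transpose); otherwise the transposed BNSD
       // tensors are used.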
2347   if (d_value == kNumDValue) {
2348     fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value, fa_parm);
2349   } else if (fa_parm->format_bsh) {
2350     fa_node = CreatePromptFlashAttentionCnodeForBSH(func_graph, node, top_matmul_q, top_matmul_k, top_matmul_v, nullptr,
2351                                                     num_head, next_tokens, scale_value, fa_parm);
2352     actual_BSH = true;
2353   } else {
2354     fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
2355                                                      next_tokens, scale_value, fa_parm, num_head);
2356   }
2357   MS_CHECK_TRUE_RET(fa_node != nullptr, nullptr);
2358   auto manager = Manage(func_graph);
2359   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2360   if (actual_BSH) {
2361     (void)manager->Replace(node, fa_node);
2362   } else {
2363     (void)manager->Replace(cnode, fa_node);
2364   }
2365   MS_LOG(INFO) << "create prompt flash attention success for without cast, BSH: " << actual_BSH;
2366   return nullptr;
2367 }
2368 
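     // PanGu pattern: Q*K^T is scaled by a Div, the attention-mask branch (Mul) is Added to the
     // cast scores before the Softmax, and Q/K/V are required to be static 4-D (BNSD) tensors.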
2369 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForPanGu(
2370   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2371   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2372   auto matmul_2 = node->cast<CNodePtr>();
2373   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2374   auto cast_2 = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2375   MS_CHECK_TRUE_RET(cast_2 != nullptr, nullptr);
2376   auto softmax = cast_2->input(kNumIndex1)->cast<CNodePtr>();
2377   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2378   auto add = softmax->input(kNumIndex1)->cast<CNodePtr>();
2379   MS_CHECK_TRUE_RET(add != nullptr, nullptr);
2380   auto atten_mask_mul = add->input(kNumIndex1)->cast<CNodePtr>();
2381   MS_CHECK_TRUE_RET(atten_mask_mul != nullptr, nullptr);
2382   auto cast_1 = add->input(kNumIndex2)->cast<CNodePtr>();
2383   MS_CHECK_TRUE_RET(cast_1 != nullptr, nullptr);
2384   auto matmul_1 = cast_1->input(kNumIndex1)->cast<CNodePtr>();
2385   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2386   auto div = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2387   MS_CHECK_TRUE_RET(div != nullptr, nullptr);
2388 
2389   // PromptFlashAttention input tensor
2390   auto q = div->input(kNumIndex1);
2391   MS_CHECK_TRUE_RET(q != nullptr, nullptr);
2392   auto k = matmul_1->input(kNumIndex2);
2393   MS_CHECK_TRUE_RET(k != nullptr, nullptr);
2394   auto v = matmul_2->input(kNumIndex2);
2395   MS_CHECK_TRUE_RET(v != nullptr, nullptr);
2396   auto atten_mask = atten_mask_mul->input(kNumIndex1)->cast<CNodePtr>();
2397   MS_CHECK_TRUE_RET(atten_mask != nullptr, nullptr);
2398 
2399   auto input_tensor_q_shape = GetTensorShape(div, kNumIndex1);
2400   if (input_tensor_q_shape.size() != kNumDimSize4) {
2401     MS_LOG(ERROR) << "q shape is not 4 dims";
2402     return nullptr;
2403   }
2404   auto input_tensor_k_shape = GetTensorShape(matmul_1, kNumIndex2);
2405   if (input_tensor_k_shape.size() != kNumDimSize4) {
2406     MS_LOG(ERROR) << "k shape is not 4 dims";
2407     return nullptr;
2408   }
2409   auto input_tensor_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2410   if (input_tensor_v_shape.size() != kNumDimSize4) {
2411     MS_LOG(ERROR) << "v shape is not 4 dims";
2412     return nullptr;
2413   }
2414   MS_LOG(INFO) << "q name: " << q->fullname_with_scope() << " , k name: " << k->fullname_with_scope()
2415                << " , v name: " << v->fullname_with_scope();
2416   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << ", k shape: " << input_tensor_k_shape
2417                << ", v shape: " << input_tensor_v_shape;
2418 
2419   // check input shape
2420   if (input_tensor_q_shape[kNumIndex3] <= 0 || input_tensor_q_shape[kNumIndex1] <= 0) {
2421     MS_LOG(ERROR) << "q head dim (D) or num_head is not positive";
2422     return nullptr;
2423   }
2424   float scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2425   int64_t seq_len = input_tensor_q_shape[kNumIndex2];
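       // seq_len > 1 maps to PromptFlashAttention with next_tokens = 0; a single-token query
       // (seq_len == 1) is treated as the incremental case and maps to IncreFlashAttention.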
2426   if (seq_len != 1) {
2427     return CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
2428                                                   input_tensor_q_shape[kNumIndex1], 0, scale_value, fa_parm,
2429                                                   input_tensor_k_shape[kNumIndex1]);
2430   } else {
2431     MS_LOG(INFO) << "seq len is 1, incre flash attention.";
2432     return CreateIncreFlashAttentionCnodeForBNSD(func_graph, node, q, k, v, atten_mask,
2433                                                  input_tensor_q_shape[kNumIndex1], scale_value,
2434                                                  input_tensor_q_shape[kNumIndex1]);
2435   }
2436   return nullptr;
2437 }
2438 
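     // LLAMA pattern V1: the Softmax is wrapped in Cast nodes and the attention mask (Mul branch)
     // is Added to the scaled Q*K^T scores. Grouped-query attention (GQA) subgraphs are routed to
     // a dedicated GQA builder; otherwise a BNSD PromptFlashAttention with attention mask is used.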
2439 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForLLAMAPatternV1(
2440   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2441   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2442   auto matmul_2 = node->cast<CNodePtr>();
2443   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2444   auto cast_2 = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2445   MS_CHECK_TRUE_RET(cast_2 != nullptr, nullptr);
2446   auto softmax = cast_2->input(kNumIndex1)->cast<CNodePtr>();
2447   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2448   auto cast_1 = softmax->input(kNumIndex1)->cast<CNodePtr>();
2449   MS_CHECK_TRUE_RET(cast_1 != nullptr, nullptr);
2450   auto add = cast_1->input(kNumIndex1)->cast<CNodePtr>();
2451   MS_CHECK_TRUE_RET(add != nullptr, nullptr);
2452 
2453   auto attention_mask_mul = add->input(kNumIndex1)->cast<CNodePtr>();
2454   MS_CHECK_TRUE_RET(attention_mask_mul != nullptr, nullptr);
2455 
2456   auto mul = add->input(kNumIndex2)->cast<CNodePtr>();
2457   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2458   auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2459   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2460 
2461   auto pfa_q_shape = GetTensorShape(matmul_1, kNumIndex1);
2462   auto pfa_k_shape = GetTensorShape(matmul_1, kNumIndex2);
2463   auto pfa_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2464   MS_LOG(INFO) << "q shape: " << pfa_q_shape << ", k shape: " << pfa_k_shape << ", v shape: " << pfa_v_shape;
2465 
2466   // process for GQA
2467   if (IsGQAPattern(matmul_1, matmul_2)) {
2468     MS_LOG(INFO) << "create GQA node for bnsd.";
2469     return CreateGQACNodeForBNSD(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2470   }
2471   return CreateFAForBNSDWithAttenMask(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2472 }
2473 
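     // LLAMA pattern V2: same structure as V1 but without the Cast nodes around the Softmax.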
2474 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForLLAMAPatternV2(
2475   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2476   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2477   auto matmul_2 = node->cast<CNodePtr>();
2478   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2479   auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2480   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2481   auto add = softmax->input(kNumIndex1)->cast<CNodePtr>();
2482   MS_CHECK_TRUE_RET(add != nullptr, nullptr);
2483 
2484   auto attention_mask_mul = add->input(kNumIndex1)->cast<CNodePtr>();
2485   MS_CHECK_TRUE_RET(attention_mask_mul != nullptr, nullptr);
2486 
2487   auto mul = add->input(kNumIndex2)->cast<CNodePtr>();
2488   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2489   auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2490   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2491 
2492   auto pfa_q_shape = GetTensorShape(matmul_1, kNumIndex1);
2493   auto pfa_k_shape = GetTensorShape(matmul_1, kNumIndex2);
2494   auto pfa_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2495   MS_LOG(INFO) << "q shape: " << pfa_q_shape << ", k shape: " << pfa_k_shape << ", v shape: " << pfa_v_shape;
2496 
2497   // process for GQA
2498   if (IsGQAPattern(matmul_1, matmul_2)) {
2499     MS_LOG(INFO) << "create GQA node for bnsd.";
2500     return CreateGQACNodeForBNSD(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2501   }
2502   return CreateFAForBNSDWithAttenMask(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2503 }
2504 
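     // BaiChuan pattern: like the LLAMA patterns, but the scaled scores pass through an extra Add
     // (likely an alibi-style positional bias) before the attention-mask Add.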
2505 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForBaiChuanPattern(
2506   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2507   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2508   auto matmul_2 = node->cast<CNodePtr>();
2509   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2510   auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2511   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2512   auto add = softmax->input(kNumIndex1)->cast<CNodePtr>();
2513   MS_CHECK_TRUE_RET(add != nullptr, nullptr);
2514 
2515   auto attention_mask_mul = add->input(kNumIndex1)->cast<CNodePtr>();
2516   MS_CHECK_TRUE_RET(attention_mask_mul != nullptr, nullptr);
2517 
2518   auto add_up = add->input(kNumIndex2)->cast<CNodePtr>();
       MS_CHECK_TRUE_RET(add_up != nullptr, nullptr);
2519   auto mul = add_up->input(kNumIndex1)->cast<CNodePtr>();
2520   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2521   auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2522   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2523   // process for GQA
2524   auto pfa_q_shape = GetTensorShape(matmul_1, kNumIndex1);
2525   auto pfa_k_shape = GetTensorShape(matmul_1, kNumIndex2);
2526   auto pfa_v_shape = GetTensorShape(matmul_2, kNumIndex2);
2527   MS_LOG(INFO) << "q shape: " << pfa_q_shape << ", k shape: " << pfa_k_shape << ", v shape: " << pfa_v_shape;
2528   if (IsGQAPattern(matmul_1, matmul_2)) {
2529     MS_LOG(INFO) << "create GQA node for BNSD.";
2530     return CreateGQACNodeForBNSD(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2531   }
2532   return CreateFAForBNSDWithAttenMask(func_graph, node, matmul_1, matmul_2, attention_mask_mul, fa_parm);
2533 }
2534 
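     // SD pattern matched from Einsum-style attention: K feeds the first MatMul directly without
     // an explicit Transpose; the rest of the handling mirrors CreateFlashAttentionNodeForSD.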
2535 CNodePtr FlashAttentionFusion::CreateFlashAttentionNodeForSDEinsum(
2536   const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv,
2537   const std::shared_ptr<FlashAttentionParm> &fa_parm) const {
2538   auto cnode = node->cast<CNodePtr>();
2539   auto reshape_o2 = cnode;
2540   MS_CHECK_TRUE_RET(reshape_o2 != nullptr, nullptr);
2541   auto output_trans = reshape_o2->input(kNumIndex1)->cast<CNodePtr>();
2542   MS_CHECK_TRUE_RET(output_trans != nullptr, nullptr);
2543   cnode = output_trans->input(kNumIndex1)->cast<CNodePtr>();  // reshape
2544   MS_CHECK_TRUE_RET(cnode != nullptr, nullptr);
2545   auto matmul_2 = cnode->input(1)->cast<CNodePtr>();
2546   MS_CHECK_TRUE_RET(matmul_2 != nullptr, nullptr);
2547   auto softmax = matmul_2->input(kNumIndex1)->cast<CNodePtr>();
2548   MS_CHECK_TRUE_RET(softmax != nullptr, nullptr);
2549   auto mul = softmax->input(kNumIndex1)->cast<CNodePtr>();
2550   MS_CHECK_TRUE_RET(mul != nullptr, nullptr);
2551   auto matmul_1 = mul->input(kNumIndex1)->cast<CNodePtr>();
2552   MS_CHECK_TRUE_RET(matmul_1 != nullptr, nullptr);
2553   auto q_reshape = matmul_1->input(kNumIndex1)->cast<CNodePtr>();
2554   MS_CHECK_TRUE_RET(q_reshape != nullptr, nullptr);
2555   auto k_reshape = matmul_1->input(kNumIndex2)->cast<CNodePtr>();
2556   MS_CHECK_TRUE_RET(k_reshape != nullptr, nullptr);
2557   auto v_reshape = matmul_2->input(kNumIndex2)->cast<CNodePtr>();
2558   MS_CHECK_TRUE_RET(v_reshape != nullptr, nullptr);
2559 
2560   auto q_trans = q_reshape->input(kNumIndex1);
2561   MS_CHECK_TRUE_RET(q_trans != nullptr, nullptr);
2562   auto k_trans = k_reshape->input(kNumIndex1);
2563   MS_CHECK_TRUE_RET(k_trans != nullptr, nullptr);
2564   auto v_trans = v_reshape->input(kNumIndex1);
2565   MS_CHECK_TRUE_RET(v_trans != nullptr, nullptr);
2566 
2567   auto input_tensor_q_shape = GetTensorShape(q_reshape, kNumIndex1);
2568   auto input_tensor_k_shape = GetTensorShape(k_reshape, kNumIndex1);
2569   auto input_tensor_v_shape = GetTensorShape(v_reshape, kNumIndex1);
2570   MS_LOG(INFO) << "q shape: " << input_tensor_q_shape << " , k shape: " << input_tensor_k_shape
2571                << " , v shape: " << input_tensor_v_shape;
2572 
2573   float scale_value = 0;
2574   int64_t num_head = 0;
2575   int64_t next_tokens = kNumMaxNextTokenSize;
2576   int64_t d_value = 0;
2577   auto mul_const_input = mul->input(kNumIndex2);
2578 
2579   if (input_tensor_q_shape.size() != kNumShapeSize4) {
2580     scale_value = GetScaleValueForDynamicShape(mul_const_input);
2581     d_value = 1 / pow(scale_value, kNumPowerTwo);
2582     // process bnsd shape
2583     MS_LOG(INFO) << "get flash attention param for dynamic shape, scale value is " << scale_value;
2584     std::vector<int32_t> new_shape = {0, 0, -1};
2585     auto shape_node = BuildIntVecParameterNode(func_graph, new_shape, node->fullname_with_scope() + "_new_shape");
2586     auto output_shape_node = node->cast<CNodePtr>();
2587     output_shape_node->set_input(kNumIndex2, shape_node);
2588     auto q_trans_reshape = q_trans->cast<CNodePtr>()->input(kNumIndex1);
2589     num_head = GetNumHeadForSD(q_trans_reshape);
2590   } else if (input_tensor_q_shape.size() == kNumShapeSize4) {
2591     MS_LOG(INFO) << "get flash attention param for static shape.";
2592     // for static shape: get scale value
2593     scale_value = 1 / (pow(input_tensor_q_shape[kNumIndex3], kNumPowerHalf));
2594     num_head = input_tensor_q_shape[kNumIndex1];
2595     d_value = input_tensor_q_shape[kNumIndex3];
2596   } else {
2597     MS_LOG(WARNING) << "unexpected Q input tensor shape: " << input_tensor_q_shape;
2598     return nullptr;
2599   }
2600   CNodePtr fa_node = nullptr;
2601   if (!PFACheckShape(scale_value, input_tensor_q_shape, input_tensor_k_shape, input_tensor_v_shape,
2602                      fa_parm->seq_threshold)) {
2603     return nullptr;
2604   }
2605 
2606   if (d_value == kNumDValue) {
2607     fa_node = CreateFAForSD15(func_graph, node, q_trans, k_trans, v_trans, num_head, next_tokens, scale_value, fa_parm);
2608   } else {
2609     fa_node = CreatePromptFlashAttentionCnodeForBNSD(func_graph, node, q_trans, k_trans, v_trans, nullptr, num_head,
2610                                                      next_tokens, scale_value, fa_parm, num_head);
2611   }
2612   if (fa_node == nullptr) {
2613     return nullptr;
2614   }
2615   auto manager = Manage(func_graph);
       MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2616   (void)manager->Replace(cnode, fa_node);
2617   MS_LOG(INFO) << "create prompt flash attention success for stable diffusion.";
2618   return nullptr;
2619 }
2620 
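     // Parses user-supplied FlashAttention attributes (input_layout, seq_threshold, inner_precise,
     // sparse_mode) from op_attrs_map_ into a FlashAttentionParm. Any invalid or unsupported value
     // returns nullptr, which makes the caller skip the fusion.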
2621 std::shared_ptr<FlashAttentionParm> FlashAttentionFusion::ParseFAParam() const {
2622   FlashAttentionParm fa_param;
2623   //  op_attrs=FlashAttention:input_layout:BSH;
2624   //           FlashAttention:seq_threshold:1024;
2625   //           FlashAttention:inner_precise:1;
2626   //           FlashAttention:sparse_mode:0
2627   if (op_attrs_map_.find("FlashAttention") != op_attrs_map_.end()) {
2628     auto attr_map = op_attrs_map_.at("FlashAttention");
2629     for (const auto &attr : attr_map) {
2630       auto attr_value = attr.second;
2631       if (attr.first == "input_layout") {
2632         if (strcmp(attr_value.c_str(), "BSH") == 0) {
2633           fa_param.format_bsh = true;
2634           MS_LOG(INFO) << "Use user config, FA input_layout is: " << fa_param.format_bsh;
2635         } else if (strcmp(attr_value.c_str(), "BNSD") == 0) {
2636           fa_param.format_bsh = false;
2637           MS_LOG(INFO) << "Use user config, FA input_layout is: " << fa_param.format_bsh;
2638         } else {
2639           MS_LOG(WARNING) << "FA input_layout only supports BSH and BNSD, but got " << attr_value;
2640           return nullptr;
2641         }
2642       } else if (attr.first == "seq_threshold") {
2643         int seq_threshold = std::atoi(attr_value.c_str());
2644         if (std::to_string(seq_threshold) == attr_value && seq_threshold >= 0) {
2645           fa_param.seq_threshold = mindspore::IntToLong(seq_threshold);
2646           MS_LOG(INFO) << "Use user config, FA seq_threshold is: " << fa_param.seq_threshold;
2647         } else {
2648           MS_LOG(WARNING) << "FA seq_threshold only supports a non-negative integer, but got " << attr_value;
2649           return nullptr;
2650         }
2651       } else if (attr.first == "inner_precise") {
2652         if (FlashAttentionFusion::GetSocVersion() == kSocVersionAscend310P) {
2653           MS_LOG(WARNING) << "FA inner_precise is not supported on Ascend310P.";
2654           return nullptr;
2655         }
2656         int inner_precise = std::atoi(attr_value.c_str());
2657         if (std::to_string(inner_precise) == attr_value && (inner_precise == 0 || inner_precise == 1)) {
2658           MS_LOG(INFO) << "Use user config, FA inner_precise is: " << attr_value;
2659           fa_param.inner_precise = inner_precise;
2660         } else {
2661           MS_LOG(WARNING) << "FA inner_precise only supports 0 or 1, but got " << attr_value;
2662           return nullptr;
2663         }
2664       } else if (attr.first == "sparse_mode") {
2665         if (FlashAttentionFusion::GetSocVersion() != kSocVersionAscend310P) {
2666           MS_LOG(WARNING) << "FA sparse_mode is only supported on Ascend310P, but the SoC version is "
2667                           << FlashAttentionFusion::GetSocVersion();
2668           return nullptr;
2669         }
2670         int sparse_mode = std::atoi(attr_value.c_str());
2671         if (std::to_string(sparse_mode) == attr_value && (sparse_mode == 0 || sparse_mode == 10)) {
2672           MS_LOG(INFO) << "Use user config, FA sparse_mode is: " << attr_value;
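               // Note: the parsed sparse_mode is stored in the inner_precise field, which appears
               // to be reused to carry the sparse mode on Ascend 310P.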
2673           fa_param.inner_precise = sparse_mode;
2674         } else {
2675           MS_LOG(WARNING) << "FA sparse_mode only supports 0 or 10, but got " << attr_value;
2676           return nullptr;
2677         }
2678       } else {
2679         MS_LOG(WARNING) << "FA attr only supports input_layout, seq_threshold, inner_precise and sparse_mode, but got "
2680                         << attr.first;
2681         return nullptr;
2682       }
2683     }
2684   }
2685   auto fa_param_ptr = std::make_shared<FlashAttentionParm>(fa_param);
2686   return fa_param_ptr;
2687 }
2688 
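     // Pattern-dispatch entry point: validates the matched node, parses the FA attributes, and
     // invokes the creator that corresponds to the matched pattern name. Creators either return
     // the fused node (replaced below) or rewire the graph themselves and return nullptr.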
2689 AnfNodePtr FlashAttentionFusion::Process(const std::string &patten_name, const FuncGraphPtr &func_graph,
2690                                          const AnfNodePtr &node, const EquivPtr &equiv) const {
2691   MS_LOG(INFO) << "do flash attention fusion, pattern name: " << patten_name;
2692   MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
2693   MS_CHECK_TRUE_RET(node != nullptr, nullptr);
2694   MS_CHECK_TRUE_RET(equiv != nullptr, nullptr);
2695   if (!utils::isa<CNodePtr>(node)) {
2696     MS_LOG(ERROR) << "this node is not cnode, node name: " << node->fullname_with_scope();
2697     return nullptr;
2698   }
2699   if (IsMarkedTrainOp(utils::cast<CNodePtr>(node))) {
2700     MS_LOG(ERROR) << "node is train op, can not fusion.";
2701     return nullptr;
2702   }
2703   auto manager = Manage(func_graph);
2704   MS_CHECK_TRUE_RET(manager != nullptr, nullptr);
2705   CNodePtr flash_attention_node = nullptr;
2706   auto fa_param_ptr = ParseFAParam();
2707   MS_CHECK_TRUE_RET(fa_param_ptr != nullptr, nullptr);
2708   if (patten_name == kNameFlashAttentionPatternForSDBSH) {
2709     MS_LOG(INFO) << "start create flash attention node for stable diffusion.";
2710     flash_attention_node = CreateFlashAttentionNodeForSD(patten_name, func_graph, node, equiv, fa_param_ptr);
2711   } else if (patten_name == kNameFlashAttentionPatternForPanGu) {
2712     MS_LOG(INFO) << "start create flash attention node for PanGu models.";
2713     flash_attention_node = CreateFlashAttentionNodeForPanGu(patten_name, func_graph, node, equiv, fa_param_ptr);
2714   } else if (patten_name == kNameFlashAttentionPatternForLLAMAPatternV1) {
2715     MS_LOG(INFO) << "start create flash attention node for LLAMA Pattern V1.";
2716     flash_attention_node =
2717       CreateFlashAttentionNodeForLLAMAPatternV1(patten_name, func_graph, node, equiv, fa_param_ptr);
2718   } else if (patten_name == kNameFlashAttentionPatternForMsSDPseShift) {
2719     MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion pse-shift.";
2720     flash_attention_node = CreateFlashAttentionNodeForMsSDPseShift(patten_name, func_graph, node, equiv, fa_param_ptr);
2721   } else if (patten_name == kNameFlashAttentionPatternForLLAMAPatternV2) {
2722     MS_LOG(INFO) << "start create flash attention node for LLAMA Pattern V2.";
2723     flash_attention_node =
2724       CreateFlashAttentionNodeForLLAMAPatternV2(patten_name, func_graph, node, equiv, fa_param_ptr);
2725   } else if (patten_name == kNameFlashAttentionPatternForBaiChuan) {
2726     MS_LOG(INFO) << "start create flash attention node for BaiChuan models.";
2727     flash_attention_node =
2728       CreateFlashAttentionNodeForBaiChuanPattern(patten_name, func_graph, node, equiv, fa_param_ptr);
2729   } else if (patten_name == kNameFlashAttentionPatternForVideoComposer) {
2730     MS_LOG(INFO) << "start create flash attention node for Video Composer models.";
2731     flash_attention_node = CreateFlashAttentionNodeForVideoComposer(patten_name, func_graph, node, equiv, fa_param_ptr);
2732   } else if (patten_name == kNameFlashAttentionPatternForMsSDXL) {
2733     MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion XL version.";
2734     flash_attention_node = CreateFlashAttentionNodeForMsSDXL(patten_name, func_graph, node, equiv, fa_param_ptr);
2735   } else if (patten_name == kNameFlashAttentionPatternForMsSD21) {
2736     MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion 2.1 version.";
2737     flash_attention_node = CreateFlashAttentionNodeForMsSD21(patten_name, func_graph, node, equiv, fa_param_ptr);
2738   } else if (patten_name == kNameFlashAttentionPatternForSDPreMul) {
2739     MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion PreMul.";
2740     flash_attention_node = CreateFlashAttentionNodeForSDPreMul(patten_name, func_graph, node, equiv, fa_param_ptr);
2741   } else if (patten_name == kNameFlashAttentionPatternForSDWithoutCast) {
2742     MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion without cast.";
2743     flash_attention_node = CreateFlashAttentionNodeForSDWithoutCast(patten_name, func_graph, node, equiv, fa_param_ptr);
2744   } else if (patten_name == kNameFlashAttentionPatternForSDEinsum) {
2745     MS_LOG(INFO) << "start create flash attention node for mindspore stable diffusion with Einsum.";
2746     flash_attention_node = CreateFlashAttentionNodeForSDEinsum(patten_name, func_graph, node, equiv, fa_param_ptr);
2747   } else {
2748     MS_LOG(ERROR) << "unsupported flash attention fusion pattern: " << patten_name;
2749   }
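       // The SD-style creators replace the matched subgraph themselves and return nullptr; for
       // those, the Replace below is skipped and the pass simply reports no direct substitution.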
2750   if (flash_attention_node == nullptr) {
2751     MS_LOG(INFO) << "flash attention op not fusion.";
2752     return nullptr;
2753   }
2754   manager->Replace(node, flash_attention_node);
2755   MS_LOG(INFO) << "flash attention node fusion success, fusion node name: "
2756                << flash_attention_node->fullname_with_scope();
2757   return flash_attention_node;
2758 }
2759 }  // namespace mindspore::opt
2760