• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FLASH_ATTENTION_FUSION_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FLASH_ATTENTION_FUSION_H_
18 
19 #include <string>
20 #include <memory>
21 #include <vector>
22 #include <unordered_map>
23 #include "include/backend/optimizer/optimizer.h"
24 
25 namespace mindspore {
26 namespace opt {
27 class FlashAttentionFusion : public PatternProcessPass {
28  public:
29   explicit FlashAttentionFusion(const std::string &name = "", bool multigraph = true)
PatternProcessPass(name,multigraph)30       : PatternProcessPass(name, multigraph) {}
31   ~FlashAttentionFusion() override = default;
32   const BaseRef DefinePattern() const override;
33   const AnfNodePtr Process(const FuncGraphPtr &graph, const AnfNodePtr &node, const EquivPtr &equiv) const override;
34 
35  protected:
36   CNodePtr CreatePromptFlashAttentionCnodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
37                                                   const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
38                                                   const AnfNodePtr &atten_mask, const int64_t num_heads,
39                                                   const int64_t next_token, const float scale_value,
40                                                   const int64_t num_key_value_heads) const;
41 
42  private:
43   virtual CNodePtr CreateFlashAttentionNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
44                                             const EquivPtr &equiv) const = 0;
45   virtual const VectorRef DefineFlashAttentionPattern() const = 0;
46 };
47 
48 class FlashAttentionFusionV1 : public FlashAttentionFusion {
49  public:
50   explicit FlashAttentionFusionV1(bool multigraph = true)
51       : FlashAttentionFusion("FlashAttentionFusionV1", multigraph) {}
52   ~FlashAttentionFusionV1() override = default;
53 
54  private:
55   CNodePtr CreateFlashAttentionNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
56                                     const EquivPtr &equiv) const override;
57   const VectorRef DefineFlashAttentionPattern() const override;
58 };
59 
60 class FlashAttentionFusionV2 : public FlashAttentionFusion {
61  public:
62   explicit FlashAttentionFusionV2(bool multigraph = true)
63       : FlashAttentionFusion("FlashAttentionFusionV2", multigraph) {}
64   ~FlashAttentionFusionV2() override = default;
65 
66  private:
67   CNodePtr CreateFlashAttentionNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
68                                     const EquivPtr &equiv) const override;
69   const VectorRef DefineFlashAttentionPattern() const override;
70 };
71 }  // namespace opt
72 }  // namespace mindspore
73 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_FLASH_ATTENTION_FUSION_H_
74