/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FLASH_ATTENTION_BASE_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FLASH_ATTENTION_BASE_FUSION_H_

#include <string>
#include <memory>
#include <utility>
#include <vector>
#include <map>
#include <unordered_map>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "tools/optimizer/common/gllo_utils.h"

namespace mindspore {
namespace opt {
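// Tunable knobs for the flash-attention fusion; an instance is produced by
// FlashAttentionFusion::ParseFAParam() below.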
struct FlashAttentionParm {
  bool format_bsh = false;     // true: emit the BSH-layout fused op; false: BNSD layout
  int64_t seq_threshold = 0;   // sequence-length threshold for enabling the fusion
  int inner_precise = 1;       // precision-mode attribute forwarded to the fused op
  int sparse_mode = 0;         // sparse-mode attribute forwarded to the fused op
};
/*
 * --------------------------------------------------------------------------------------------------------
 *  Pattern 1:                                      |   Pattern 2:
 *    transpose input[0] is input[K] -> transpose   |     transpose input[0] is input[K] -> transpose
 *      matmul  input[0] is input[Q] ->   matmul    |       matmul  input[0] is input[Q] ->   matmul
 *                                         mul      |                                          mul
 *                                        cast      |                                        softMax
 *                                       softMax    |                                         cast
 *                                        cast      |       matmul  input[0] is input[V] ->  matmul
 *      matmul  input[0] is input[V] ->  matmul     |
 * --------------------------------------------------------------------------------------------------------
 */
class FlashAttentionFusion : public MultiplePatternProcessPass {
 public:
  explicit FlashAttentionFusion(std::map<std::string, std::map<std::string, std::string>> op_attrs_map,
                                const std::string &name = "FlashAttentionFusion", bool multigraph = true)
      : MultiplePatternProcessPass(name, multigraph), op_attrs_map_(std::move(op_attrs_map)) {}

  ~FlashAttentionFusion() override = default;

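  // Returns the table of pattern name -> pattern root consumed by the multi-pattern pass framework.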
  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;

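  // Invoked for each node matched by one of the patterns; returns the fused node on success,
  // or nullptr to leave the graph unchanged.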
  AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

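  // Target Ascend SoC version; expected to be set by the caller before the pass runs.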
  static void SetSocVersion(const std::string &soc_version) { soc_version_ = soc_version; }

  static std::string GetSocVersion() { return soc_version_; }

 private:
  std::map<std::string, std::map<std::string, std::string>> op_attrs_map_;

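  // Builders for the fused attention cnodes. BNSD and BSH name the expected tensor layout:
  // (batch, num_heads, seq_len, head_dim) vs. (batch, seq_len, hidden_size).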
  CNodePtr CreatePromptFlashAttentionCnodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                  const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token,
                                                  float scale_value, const std::shared_ptr<FlashAttentionParm> &fa_parm,
                                                  int64_t num_key_value_heads = 1) const;

  CNodePtr CreatePromptFlashAttentionCnodeForBNSDWithPse(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                         const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                         const AnfNodePtr &atten_mask, const AnfNodePtr &pse,
                                                         int64_t num_heads, int64_t next_token, float scale_value,
                                                         const std::shared_ptr<FlashAttentionParm> &fa_parm,
                                                         int64_t num_key_value_heads = 1) const;

  CNodePtr CreatePromptFlashAttentionCnodeForBSH(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                 const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                 const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token,
                                                 float scale_value,
                                                 const std::shared_ptr<FlashAttentionParm> &fa_parm) const;

  CNodePtr CreateIncreFlashAttentionCnodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                 const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                 const AnfNodePtr &atten_mask, int64_t num_heads, float scale_value,
                                                 int64_t num_key_value_heads) const;
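
  // Pattern-specific rewrite entry points, one per supported model family
  // (Stable Diffusion variants, VideoComposer, PanGu, LLaMA, Baichuan).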
  CNodePtr CreateFlashAttentionNodeForMsSD21(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                             const AnfNodePtr &node, const EquivPtr &equiv,
                                             const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForMsSDPseShift(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                   const AnfNodePtr &node, const EquivPtr &equiv,
                                                   const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForMsSDXL(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                             const AnfNodePtr &node, const EquivPtr &equiv,
                                             const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForVideoComposer(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                    const AnfNodePtr &node, const EquivPtr &equiv,
                                                    const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForSD(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                         const AnfNodePtr &node, const EquivPtr &equiv,
                                         const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForSDPreMul(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                               const AnfNodePtr &node, const EquivPtr &equiv,
                                               const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForSDWithoutCast(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                    const AnfNodePtr &node, const EquivPtr &equiv,
                                                    const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForPanGu(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                            const AnfNodePtr &node, const EquivPtr &equiv,
                                            const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForLLAMAPatternV1(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                     const AnfNodePtr &node, const EquivPtr &equiv,
                                                     const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForLLAMAPatternV2(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                     const AnfNodePtr &node, const EquivPtr &equiv,
                                                     const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForBaiChuanPattern(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                      const AnfNodePtr &node, const EquivPtr &equiv,
                                                      const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForSDEinsum(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                               const AnfNodePtr &node, const EquivPtr &equiv,
                                               const std::shared_ptr<FlashAttentionParm> &fa_parm) const;

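  // Shape and layout helpers; the Pad/Slice pair presumably aligns the head dimension
  // to the fused op's size requirements.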
  CNodePtr CreatePadCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int32_t pad_size,
                          const std::string &node_name = "") const;
  CNodePtr CreateSliceCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int32_t slice_size) const;
  CNodePtr GetSDDynamicShapeParam(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
  float GetScaleValueForDynamicShape(const AnfNodePtr &mul_const_input) const;
  CNodePtr CreateFAForSD15(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q_trans,
                           const AnfNodePtr &k_trans, const AnfNodePtr &v_trans, int64_t num_head, int64_t next_token,
                           float scale_value, const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFAWithPadAndPse(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q_trans,
                                 const AnfNodePtr &k_trans, const AnfNodePtr &v_trans, const AnfNodePtr &pse,
                                 int64_t num_head, int64_t next_token, float scale_value,
                                 const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateGQACNodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const CNodePtr &matmul_1,
                                 const CNodePtr &matmul_2, const CNodePtr &attention_mask_mul,
                                 const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFAForBNSDWithAttenMask(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const CNodePtr &qk_matmul, const CNodePtr &v_matmul,
                                        const CNodePtr &attention_mask_mul,
                                        const std::shared_ptr<FlashAttentionParm> &fa_parm) const;

  CNodePtr CreateFACNodeWithoutAttenMask(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const CNodePtr &qk_matmul, const CNodePtr &v_matmul,
                                         const CNodePtr &attention_mask_mul,
                                         const std::shared_ptr<FlashAttentionParm> &fa_parm) const;

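  // Pattern definitions, one per model family; each returns the VectorRef that roots the match.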
  const VectorRef DefineFlashAttentionPatternForMsSD21() const;

  /*
   * --------------------------------------------------
   *  Pattern PseShift:                               |
   *   trans input[1] is reshape[input[K]] -> trans   |
   *  matmul input[1] is reshape[input[Q]] -> matmul  |
   *                                          mul     |
   *                                          add     |
   *                                          softMax |
   *                                          cast    |
   * matmul input[2] is reshape[input[V]] ->  matmul  |
   *                                          reshape |
   * --------------------------------------------------
   */
  const VectorRef DefineFlashAttentionPatternForMsSDPseShift() const;

  const VectorRef DefineFlashAttentionPatternForVideoComposer() const;
  const VectorRef DefineFlashAttentionPatternForMsSDXL() const;
  const VectorRef DefineFlashAttentionPatternForSDBNSD() const;
  const VectorRef DefineFlashAttentionPatternForSDBSH() const;
  const VectorRef DefineFlashAttentionPatternForSDPreMul() const;
  const VectorRef DefineFlashAttentionPatternForSDWithoutCast() const;
  const VectorRef DefineFlashAttentionPatternForPanGu() const;
  const VectorRef DefineFlashAttentionPatternForLLAMAPatternV1() const;
  const VectorRef DefineFlashAttentionPatternForLLAMAPatternV2() const;
  const VectorRef DefineFlashAttentionPatternForBaiChuan() const;

  /*
   * --------------------------------------------------
   *  Pattern SD with Einsum:                         |
   *  (Note: Einsum is replaced by matmul             |
   *         in the onnx parser)                      |
   *                                          input[K]|
   *                                          reshape |
   * einsum input[0] is reshape[input[Q]] ->  einsum  |
   *                                          mul     |
   *                                          softMax |
   * einsum input[1] is reshape[input[V]] ->  einsum  |
   *                                          reshape |
   * --------------------------------------------------
   */
  const VectorRef DefineFlashAttentionPatternForSDEinsum() const;

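  // Collects the fusion parameters (layout, sequence threshold, precision/sparse modes),
  // presumably from op_attrs_map_.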
  std::shared_ptr<FlashAttentionParm> ParseFAParam() const;

  static std::string soc_version_;
};
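
// Minimal usage sketch (illustrative only: how the pass is registered and run depends on the
// surrounding optimizer framework; "Ascend910B" is a placeholder version string and op_attrs_map
// is the per-op attribute table described above):
//   FlashAttentionFusion::SetSocVersion("Ascend910B");
//   auto fa_fusion = std::make_shared<FlashAttentionFusion>(op_attrs_map);
//   // ... add fa_fusion to the graph-optimization pass list ...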
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FLASH_ATTENTION_BASE_FUSION_H_