/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FLASH_ATTENTION_BASE_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FLASH_ATTENTION_BASE_FUSION_H_

#include <string>
#include <memory>
#include <vector>
#include <map>
#include <unordered_map>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
namespace opt {
struct FlashAttentionParm {
  bool format_bsh = false;
  int64_t seq_threshold = 0;
  int inner_precise = 1;
  int sparse_mode = 0;
};
/*
 *
 * --------------------------------------------------------------------------------------------------------
 * Pattern 1:                                          | Pattern 2:
 *   transpose  input[0] is input[K] -> transpose      |   transpose  input[0] is input[K] -> transpose
 *   matmul     input[0] is input[Q] -> matmul         |   matmul     input[0] is input[Q] -> matmul
 *   mul                                               |   mul
 *   cast                                              |   softMax
 *   softMax                                           |   cast
 *   cast                                              |   matmul     input[0] is input[V] -> matmul
 *   matmul     input[0] is input[V] -> matmul         |
 * --------------------------------------------------------------------------------------------------------
 *
 */
class FlashAttentionFusion : public MultiplePatternProcessPass {
 public:
  explicit FlashAttentionFusion(std::map<std::string, std::map<std::string, std::string>> op_attrs_map,
                                const std::string &name = "FlashAttentionFusion", bool multigraph = true)
      : MultiplePatternProcessPass(name, multigraph) {
    op_attrs_map_ = op_attrs_map;
  }

  ~FlashAttentionFusion() override = default;

  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;

  AnfNodePtr Process(const std::string &, const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

  static void SetSocVersion(const std::string &soc_version) { soc_version_ = soc_version; }

  static std::string GetSocVersion() { return soc_version_; }

 private:
  std::map<std::string, std::map<std::string, std::string>> op_attrs_map_;

  CNodePtr CreatePromptFlashAttentionCnodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                  const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token,
                                                  float scale_value,
                                                  const std::shared_ptr<FlashAttentionParm> &fa_parm,
                                                  int64_t num_key_value_heads = 1) const;

  CNodePtr CreatePromptFlashAttentionCnodeForBNSDWithPse(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                         const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                         const AnfNodePtr &atten_mask, const AnfNodePtr &pse,
                                                         int64_t num_heads, int64_t next_token, float scale_value,
                                                         const std::shared_ptr<FlashAttentionParm> &fa_parm,
                                                         int64_t num_key_value_heads = 1) const;

  CNodePtr CreatePromptFlashAttentionCnodeForBSH(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                 const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                 const AnfNodePtr &atten_mask, int64_t num_heads, int64_t next_token,
                                                 float scale_value,
                                                 const std::shared_ptr<FlashAttentionParm> &fa_parm) const;

  CNodePtr CreateIncreFlashAttentionCnodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                 const AnfNodePtr &q, const AnfNodePtr &k, const AnfNodePtr &v,
                                                 const AnfNodePtr &atten_mask, int64_t num_heads, float scale_value,
                                                 int64_t num_key_value_heads) const;
  CNodePtr CreateFlashAttentionNodeForMsSD21(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                             const AnfNodePtr &node, const EquivPtr &equiv,
                                             const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForMsSDPseShift(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                   const AnfNodePtr &node, const EquivPtr &equiv,
                                                   const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForMsSDXL(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                             const AnfNodePtr &node, const EquivPtr &equiv,
                                             const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForVideoComposer(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                    const AnfNodePtr &node, const EquivPtr &equiv,
                                                    const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForSD(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                         const AnfNodePtr &node, const EquivPtr &equiv,
                                         const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForSDPreMul(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                               const AnfNodePtr &node, const EquivPtr &equiv,
                                               const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForSDWithoutCast(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                    const AnfNodePtr &node, const EquivPtr &equiv,
                                                    const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForPanGu(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                            const AnfNodePtr &node, const EquivPtr &equiv,
                                            const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForLLAMAPatternV1(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                     const AnfNodePtr &node, const EquivPtr &equiv,
                                                     const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForLLAMAPatternV2(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                     const AnfNodePtr &node, const EquivPtr &equiv,
                                                     const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForBaiChuanPattern(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                                      const AnfNodePtr &node, const EquivPtr &equiv,
                                                      const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFlashAttentionNodeForSDEinsum(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                               const AnfNodePtr &node, const EquivPtr &equiv,
                                               const std::shared_ptr<FlashAttentionParm> &fa_parm) const;

  CNodePtr CreatePadCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int32_t pad_size,
                          const std::string &node_name = "") const;
  CNodePtr CreateSliceCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, int32_t slice_size) const;
  CNodePtr GetSDDynamicShapeParam(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const;
  float GetScaleValueForDynamicShape(const AnfNodePtr &mul_const_input) const;
  CNodePtr CreateFAForSD15(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q_trans,
                           const AnfNodePtr &k_trans, const AnfNodePtr &v_trans, int64_t num_head, int64_t next_token,
                           float scale_value, const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFAWithPadAndPse(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const AnfNodePtr &q_trans,
                                 const AnfNodePtr &k_trans, const AnfNodePtr &v_trans, const AnfNodePtr &pse,
                                 int64_t num_head, int64_t next_token, float scale_value,
                                 const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateGQACNodeForBNSD(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const CNodePtr &matmul_1,
                                 const CNodePtr &matmul_2, const CNodePtr &attention_mask_mul,
                                 const std::shared_ptr<FlashAttentionParm> &fa_parm) const;
  CNodePtr CreateFAForBNSDWithAttenMask(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                        const CNodePtr &qk_matmul, const CNodePtr &v_matmul,
                                        const CNodePtr &attention_mask_mul,
                                        const std::shared_ptr<FlashAttentionParm> &fa_parm) const;

  CNodePtr CreateFACNodeWithoutAttenMask(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                         const CNodePtr &qk_matmul, const CNodePtr &v_matmul,
                                         const CNodePtr &attention_mask_mul,
                                         const std::shared_ptr<FlashAttentionParm> &fa_parm) const;

  const VectorRef DefineFlashAttentionPatternForMsSD21() const;

  /*
   * --------------------------------------------------
   * Pattern PseShift:                                 |
   *   trans    input[1] is reshape[input[K]] -> trans |
   *   matmul   input[1] is reshape[input[Q]] -> matmul|
   *   mul                                             |
   *   add                                             |
   *   softMax                                         |
   *   cast                                            |
   *   matmul   input[2] is reshape[input[V]] -> matmul|
   *   reshape                                         |
   * --------------------------------------------------
   */
  const VectorRef DefineFlashAttentionPatternForMsSDPseShift() const;

  const VectorRef DefineFlashAttentionPatternForVideoComposer() const;
  const VectorRef DefineFlashAttentionPatternForMsSDXL() const;
  const VectorRef DefineFlashAttentionPatternForSDBNSD() const;
  const VectorRef DefineFlashAttentionPatternForSDBSH() const;
  const VectorRef DefineFlashAttentionPatternForSDPreMul() const;
  const VectorRef DefineFlashAttentionPatternForSDWithoutCast() const;
  const VectorRef DefineFlashAttentionPatternForPanGu() const;
  const VectorRef DefineFlashAttentionPatternForLLAMAPatternV1() const;
  const VectorRef DefineFlashAttentionPatternForLLAMAPatternV2() const;
  const VectorRef DefineFlashAttentionPatternForBaiChuan() const;

  /*
   * --------------------------------------------------
   * Pattern SD with Einsum:                           |
   * (Note: Einsum is replaced by matmul               |
   *  in the onnx parser)                              |
   *                                          input[K] |
   *                                           reshape |
   *   einsum   input[0] is reshape[input[Q]] -> einsum|
   *   mul                                             |
   *   softMax                                         |
   *   einsum   input[1] is reshape[input[V]] -> einsum|
   *   reshape                                         |
   * --------------------------------------------------
   */
  const VectorRef DefineFlashAttentionPatternForSDEinsum() const;

  std::shared_ptr<FlashAttentionParm> ParseFAParam() const;

 private:
  static std::string soc_version_;
};
}  // namespace opt
}  // namespace mindspore
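// A minimal usage sketch, kept as a comment because this header only carries declarations. It is
// illustrative and not part of the original file: the attribute key "seq_threshold" and the SoC
// string "Ascend310P" are hypothetical placeholders; only the constructor and SetSocVersion shown
// below are declared above.
//
//   std::map<std::string, std::map<std::string, std::string>> op_attrs;
//   op_attrs["FlashAttention"]["seq_threshold"] = "1024";               // hypothetical key/value
//   mindspore::opt::FlashAttentionFusion::SetSocVersion("Ascend310P");  // hypothetical SoC version
//   auto fa_fusion = std::make_shared<mindspore::opt::FlashAttentionFusion>(op_attrs);
//   // A converter driver would then register fa_fusion with its fusion pass manager and run it
//   // over the FuncGraph, where DefinePatterns()/Process() perform the actual matching and rewrite.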
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_FLASH_ATTENTION_BASE_FUSION_H_