1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_ 17 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_ 18 19 #include <memory> 20 #include <string> 21 #include <unordered_map> 22 #include <vector> 23 #include "tools/optimizer/common/multiple_pattern_process_pass.h" 24 #include "include/common/utils/utils.h" 25 #include "include/errorcode.h" 26 #include "ops/encoder_layer.h" 27 #include "tools/optimizer/fusion/multi_head_attention_fusion.h" 28 #include "ops/fusion/layer_norm_fusion.h" 29 #include "ops/fusion/activation.h" 30 31 namespace mindspore { 32 namespace opt { 33 class EncoderLayerFusion : public MultiplePatternProcessPass { 34 public: 35 explicit EncoderLayerFusion(bool embedding_layer = false, const std::string &name = "EncoderLayerFusion", 36 bool multigraph = true) MultiplePatternProcessPass(name,multigraph)37 : MultiplePatternProcessPass(name, multigraph) { 38 embedding_layer_ = embedding_layer; 39 } 40 41 ~EncoderLayerFusion() override = default; 42 43 AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &, 44 const EquivPtr &) const override; 45 std::unordered_map<std::string, VectorRef> DefinePatterns() const override; 46 47 protected: 48 virtual bool Init() const; 49 50 private: 51 const std::string kPatternEncoderLayerPreNormUsePast = "PatternEncoderLayerPreNormUsePast"; 52 const std::string kPatternEncoderLayerUsePastWithLastNorm = "PatternEncoderLayerPreNormUsePastWithLastNorm"; 53 const std::string kPatternEncoderLayerPost = "PatternTEncoderLayerPost"; 54 const std::string kPatternEncoderLayerPre = "PatternTEncoderLayerPre"; 55 const std::string kPatternEncoderLayerPostNorm = "PatternTEncoderLayerPostNorm"; 56 const std::string kPatternEncoderLayerPreNorm = "PatternTEncoderLayerPreNorm"; 57 const std::string kPatternEncoderLayerT5Post = "PatternEncoderLayerT5Post"; 58 const std::string kPatternEncoderLayerT5Pre = "PatternEncoderLayerT5Pre"; 59 const std::string kPatternEncoderLayerNormT5Pre = "PatternEncoderLayerNormT5Pre"; 60 const std::string kPatternQueryLayerUsePast = "PatternQueryLayerUsePast"; 61 const std::string kPatternSigmaDistributed = "PatternSigmaDistributed"; 62 const std::string kPatternSigmaDistributedEmbedding = "PatternSigmaDistributedEmbedding"; 63 const std::string kPatternSigmaMoeDistributed = "PatternSigmaMoeDistributed"; 64 const std::string kPatternSigmaMoeWithLastLayerNormDistributed = "PatternSigmaMoeWithLastLayerNormDistributed"; 65 const std::string kPatternSigmaWithLastLayerNormDistributed = "PatternSigmaWithLastLayerNormDistributed"; 66 const std::string kPatternSigmaQueryLayerDistributed = "PatternSigmaQueryLayerDistributed"; 67 const std::string kPatternDistributedAlpha = "PatternDistributedAlpha"; 68 const std::string kPatternDistributedAlphaWithLastLayerNorm = "PatternDistributedAlphaWithLastLayerNorm"; 69 const std::string kPatternQueryLayerUsePastDistributed = "PatternQueryLayerUsePastDistributed"; 70 const std::string kPatternSigma = "kPatternSigma"; 71 const std::string kPatternSigmaEmbedding = "kPatternSigmaEmbedding"; 72 const std::string kPatternSigmaQuery = "kPatternSigmaQuery"; 73 const std::string kPatternSigmaMoe = "kPatternSigmaMoe"; 74 const std::string kPatternSigmaMoeWithLastLayerNorm = "PatternSigmaMoeWithLastLayerNorm"; 75 const std::string kPatternSigmaWithLastLayerNorm = "PatternSigmaWithLastLayerNorm"; 76 const std::string kPatternSigmaQueryLayerMoe = "PatternSigmaQueryLayerMoe"; 77 const std::string kPatternSigmaDistributedMB = "PatternSigmaDistributedMB"; 78 const std::string kPatternSigmaDistributedEmbeddingMB = "PatternSigmaDistributedEmbeddingMB"; 79 const std::string kPatternSigmaMoeWithLastLayerNormDistributedMB = "PatternSigmaMoeWithLastLayerNormDistributedMB"; 80 const std::string kPatternSigmaWithLastLayerNormDistributedMB = "PatternSigmaWithLastLayerNormMB"; 81 const std::string kPatternSigmaQueryLayerDistributedMB = "PatternSigmaQueryLayerDistributedMB"; 82 const std::string kPatternSigmaMoeDistributedMB = "PatternSigmaMoeDistributedMB"; 83 const std::string kPatternSigmaDistributedMBGELU = "PatternSigmaDistributedMBGELU"; 84 const std::string kPatternSigmaDistributedEmbeddingMBGELU = "PatternSigmaDistributedEmbeddingMBGELU"; 85 const std::string kPatternSigmaMoeWithLastLayerNormDistributedMBGELU = 86 "PatternSigmaMoeWithLastLayerNormDistributedMBGELU"; 87 const std::string kPatternSigmaWithLastLayerNormDistributedMBGELU = "PatternSigmaWithLastLayerNormMBGELU"; 88 const std::string kPatternSigmaQueryLayerDistributedMBGELU = "PatternSigmaQueryLayerDistributedMBGELU"; 89 const std::string kPatternSigmaMoeDistributedMBGELU = "PatternSigmaMoeDistributedMBGELU"; 90 const std::string kPatternSigmaDistributedMBFirst = "PatternSigmaMoeDistributedMBFirst"; 91 const std::string kPatternSigmaFirst = "kPatternSigmaFirsts"; 92 const std::string kPatternSigmaQueryLayerDistributedMBMoe = "kPatternSigmaQueryLayerDistributedMBMoe"; 93 const std::string kPatternSigmaQueryLayerDistributedMoe = "kPatternSigmaQueryLayerDistributedMoe"; 94 const std::string kPatternSigmaEmbeddingDistributed = "kPatternSigmaEmbeddingDistributed"; 95 96 VectorRef DefinePatternEncoderLayer(bool post_layernorm, bool layernorm_fusion, bool is_position_bias_, bool mask, 97 bool is_layer_norm) const; 98 VectorRef DefinePatternEncoderSigma(bool moe, bool use_past, bool distributed, bool is_layer_norm, bool query_layer, 99 bool multi_batch, bool first_encoder, bool gelu) const; 100 101 VectorRef DefinePatternEncoderAlpha(bool moe, bool distributed, bool is_layer_norm, bool query_layer, 102 bool use_past) const; 103 VectorRef getTuple(bool post_layernorm, bool layernorm_fusion, bool is_position_bias) const; 104 VectorRef DefineLayerNorm(bool is_position_bias, BaseRef input, VarPtr gamma, VarPtr beta, VarPtr eps) const; 105 CNodePtr CreateMaskedEncoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, 106 const AnfNodePtr &node, bool post_layernorm, bool mask) const; 107 AnfNodePtr GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv, VarPtr node_name) const; 108 bool IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const VarPtr &input_prim) const; 109 lite::STATUS GetEps(const EquivPtr &equiv, VarPtr node_name, float *eps) const; 110 lite::STATUS CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num, int *head_size, 111 float *eps1, float *eps2, float *eps3, float *scale) const; 112 std::shared_ptr<ops::EncoderLayer> CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv, 113 int64_t ffn_hidden_size, int64_t expert_num, int64_t expert_offset, 114 float capacity_factor) const; 115 VectorRef DefinePatternInitReset(VectorRef input, bool is_value_reset = false, bool is_key_reset = false) const; 116 VectorRef DefinePatternMultiBatch(VectorRef input) const; 117 BaseRef DefineBatchValidLength(const BaseRef &input) const; 118 VectorRef DefinePatternMoERouter(VectorRef input_layernorm) const; 119 VectorRef DefinePatternMoE(VectorRef input_layernorm, bool multi_batch, bool gelu) const; 120 VectorRef DefinePatternSigmaFfn(BaseRef input, bool gelu, bool distributed) const; 121 VectorRef DefinePatternMoETopKRouter(VectorRef input) const; 122 VectorRef DefinePatternMoEFfn(VectorRef input_reshape, bool gelu) const; 123 VectorRef DefineDependKV(VectorRef input_layernorm, VectorRef deppend_v_input, bool moe) const; 124 VectorRef DefineFfn(VectorRef input) const; 125 VectorRef DefineFirstEncoder(bool distributed) const; 126 lite::STATUS InitAttributes(AnfNodePtr k_past, AnfNodePtr begin_expert_ids, AnfNodePtr weight_m, 127 AnfNodePtr expert_capacity_node, int *ffn_hidden_size, int *expert_num, 128 int *expert_offset, float *capacity_factor) const; 129 void InitParams(bool post_layernorm, bool layernorm_fusion, bool is_position_bias, bool mask, bool is_layer_norm, 130 bool use_past, bool query_layer, bool sigma, bool distributed, bool moe) const; 131 bool IsUsePast(const std::string &pattern_name) const; 132 bool IsUsePastMB(const std::string &pattern_name) const; 133 bool IsUsePastAlpha(const std::string &pattern_name) const; 134 bool IsLastLayerNorm(const std::string &pattern_name) const; 135 bool IsLayerNormFusion(const std::string &pattern_name) const; 136 bool IsMoe(const std::string &pattern_name) const; 137 bool IsFastGelu(const std::string &pattern_name) const; 138 bool IsFastGeluDistributed(const std::string &pattern_name) const; 139 bool IsQueryLayer(const std::string &pattern_name) const; 140 141 protected: 142 mutable VarPtr input_{nullptr}; 143 mutable VarPtr expert_ids_input_{nullptr}; 144 mutable VarPtr expert_ids_{nullptr}; 145 mutable VarPtr expert_capacity_{nullptr}; 146 mutable VarPtr begin_expert_ids_{nullptr}; 147 mutable VarPtr position_bias_{nullptr}; 148 mutable VarPtr beta1_{nullptr}; 149 mutable VarPtr gamma1_{nullptr}; 150 mutable VarPtr beta2_{nullptr}; 151 mutable VarPtr gamma2_{nullptr}; 152 mutable VarPtr beta3_{nullptr}; 153 mutable VarPtr gamma3_{nullptr}; 154 mutable VarPtr weight_attn_qkv_{nullptr}; 155 mutable VarPtr weight_attn_q_{nullptr}; 156 mutable VarPtr weight_attn_o_{nullptr}; 157 mutable VarPtr weight_m_{nullptr}; 158 mutable VarPtr weight_p_{nullptr}; 159 mutable VarPtr bias_attn_qkv_{nullptr}; 160 mutable VarPtr bias_attn_o_{nullptr}; 161 mutable VarPtr bias_m_{nullptr}; 162 mutable VarPtr bias_p_{nullptr}; 163 mutable VarPtr mask_{nullptr}; 164 mutable VarPtr is_attention_{nullptr}; 165 mutable VarPtr is_layernorm1_{nullptr}; 166 mutable VarPtr is_layernorm2_{nullptr}; 167 mutable VarPtr is_layernorm3_{nullptr}; 168 mutable ActType act_type_{ActType::ActType_No}; 169 mutable VarPtr is_act_{nullptr}; 170 mutable VarPtr eps1_{nullptr}; 171 mutable VarPtr eps2_{nullptr}; 172 mutable VarPtr eps3_{nullptr}; 173 mutable VarPtr init_reset_{nullptr}; 174 mutable VarPtr k_past_{nullptr}; 175 mutable VarPtr v_past_{nullptr}; 176 mutable VarPtr input_q_{nullptr}; 177 mutable VarPtr batch_valid_length_{nullptr}; 178 mutable VarPtr embedding_table_{nullptr}; 179 mutable VarPtr weight_router_{nullptr}; 180 181 mutable VarPtr position_ids_{nullptr}; 182 mutable VarPtr embedding_table_input_{nullptr}; 183 mutable VarPtr current_index_{nullptr}; 184 mutable VarPtr embedding_table_pos_{nullptr}; 185 186 mutable bool is_position_bias_{false}; 187 mutable bool is_post_layernorm_{false}; 188 mutable bool is_layernorm_fusion_{false}; 189 mutable bool is_layernorm_{false}; 190 mutable bool is_use_past_{false}; 191 mutable bool is_query_layer_{false}; 192 mutable bool is_sigma_{false}; 193 mutable bool is_moe_{false}; 194 mutable bool is_distributed_{false}; 195 mutable bool is_fast_gelu_{false}; 196 mutable bool is_embedding_layer_{false}; 197 198 mutable bool embedding_layer_{false}; 199 }; 200 } // namespace opt 201 } // namespace mindspore 202 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_ 203