/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_DECODER_LAYER_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_DECODER_LAYER_FUSION_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"
#include "ops/decoder_layer.h"
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/activation.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"

namespace mindspore {
namespace opt {
// Fuses a matched transformer decoder-layer subgraph (self-attention, cross-attention,
// layer norms, and the feed-forward block) into a single ops::DecoderLayer node.
class DecoderLayerFusion : public MultiplePatternProcessPass {
 public:
  explicit DecoderLayerFusion(const std::string &name = "DecoderLayerFusion", bool multigraph = true)
      : MultiplePatternProcessPass(name, multigraph) {}

  ~DecoderLayerFusion() override = default;

  // Invoked for each named pattern that matches; returns the fused node, or nullptr on failure.
  AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
                     const EquivPtr &) const override;
  // Returns the match patterns, keyed by the kPattern* names declared below.
  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;

 protected:
  // Allocates the pattern capture variables declared below; returns false on failure.
  virtual bool Init() const;

 private:
  VectorRef DefinePatternDecoderLayer(bool post_layernorm, bool layernorm_fusion, bool is_position_bias, bool mask,
                                      bool is_layer_norm) const;
  VectorRef getTuple(bool post_layernorm, bool layernorm_fusion, bool is_position_bias) const;
  VectorRef DefineLayerNorm(VectorRef input, VarPtr gamma, VarPtr beta, VarPtr eps) const;
  CNodePtr CreateMaskedDecoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                              const AnfNodePtr &node, bool post_layernorm, bool mask) const;
  std::shared_ptr<ops::DecoderLayer> CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                bool post_layernorm, int64_t ffn_hidden_size) const;
  lite::STATUS CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num, int *head_size,
                            float *eps1, float *eps2, float *eps3, float *eps4, bool *is_position_bias1,
                            bool *is_position_bias2, float *scale1, float *scale2) const;
  AnfNodePtr GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv, VarPtr node_name) const;
  bool IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
  lite::STATUS GetEps(const EquivPtr &equiv, VarPtr node_name, float *eps) const;
  VectorRef DefineDecoderLayerNorm(VectorRef input, VarPtr gamma, VarPtr eps) const;

 protected:
  const std::string kPatternDecoderLayerPre = "PatternDecoderLayerPre";
  const std::string kPatternDecoderLayerPost = "PatternDecoderLayerPost";
  const std::string kPatternDecoderLayerNormPre = "PatternDecoderLayerNormPre";
  const std::string kPatternDecoderLayerNormPost = "PatternDecoderLayerNormPost";
  const std::string kPatternDecoderLayerNormT5Pre = "PatternDecoderLayerNormT5Pre";
  const std::string kPatternDecoderT5Pre = "PatternDecoderT5Pre";
  const std::string kPatternDecoderT5Post = "PatternDecoderT5Post";
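  // Capture variables bound during pattern matching (weights, biases, gammas/betas,
  // epsilons, masks). They are mutable because DefinePatterns() and Process() are
  // const overrides, while Init() must (re)create them before the patterns are
  // built; Process() then looks the matched nodes up in the EquivPtr through these
  // keys (see GetAttribute() above).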
= "PatternDecoderLayerNormT5Pre"; 70 const std::string kPatternDecoderT5Pre = "PatternDecoderT5Pre"; 71 const std::string kPatternDecoderT5Post = "PatternDecoderT5Post"; 72 mutable VarPtr hidden_stats_{nullptr}; 73 mutable VarPtr encoder_output_{nullptr}; 74 mutable VarPtr position_bias_{nullptr}; 75 mutable VarPtr beta1_{nullptr}; 76 mutable VarPtr gamma1_{nullptr}; 77 mutable VarPtr beta2_{nullptr}; 78 mutable VarPtr gamma2_{nullptr}; 79 mutable VarPtr gamma3_{nullptr}; 80 mutable VarPtr gamma4_{nullptr}; 81 mutable VarPtr beta3_{nullptr}; 82 mutable VarPtr beta4_{nullptr}; 83 mutable VarPtr weight_attn_qkv_{nullptr}; 84 mutable VarPtr weight_attn_qkv_cross_{nullptr}; 85 mutable VarPtr weight_attn_o_{nullptr}; 86 mutable VarPtr weight_m_{nullptr}; 87 mutable VarPtr weight_p_{nullptr}; 88 mutable VarPtr bias_attn_qkv_{nullptr}; 89 mutable VarPtr bias_attn_o_{nullptr}; 90 mutable VarPtr bias_attn_cross_qkv_{nullptr}; 91 mutable VarPtr bias_attn_cross_o_{nullptr}; 92 mutable VarPtr bias_m_{nullptr}; 93 mutable VarPtr bias_p_{nullptr}; 94 mutable VarPtr mask_{nullptr}; 95 mutable VarPtr is_attention_{nullptr}; 96 mutable VarPtr is_attention_cross_{nullptr}; 97 mutable VarPtr weight_attn_q_{nullptr}; 98 mutable VarPtr weight_attn_kv_{nullptr}; 99 mutable VarPtr weight_attn_cross_o_{nullptr}; 100 mutable VarPtr position_bias_cross_{nullptr}; 101 mutable VarPtr cross_mask_{nullptr}; 102 mutable VarPtr reshape_k_{nullptr}; 103 mutable VarPtr reshape_v_{nullptr}; 104 mutable VarPtr is_layernorm1_{nullptr}; 105 mutable VarPtr is_layernorm2_{nullptr}; 106 mutable VarPtr is_layernorm3_{nullptr}; 107 mutable VarPtr is_act_{nullptr}; 108 mutable VarPtr eps1_{nullptr}; 109 mutable VarPtr eps2_{nullptr}; 110 mutable VarPtr eps3_{nullptr}; 111 mutable VarPtr eps4_{nullptr}; 112 mutable bool is_position_bias_{false}; 113 mutable bool is_layernorm_fusion_{false}; 114 mutable bool is_layernorm_{false}; 115 mutable ActType act_type_{ActType::ActType_No}; 116 mutable bool layer_norm_; 117 }; 118 } // namespace opt 119 } // namespace mindspore 120 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_DECODER_LAYER_FUSION_H_ 121