/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"
#include "ops/encoder_layer.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/activation.h"

namespace mindspore {
namespace opt {
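// Fuses transformer encoder-layer subgraphs (multi-head attention, layer
// norms and the feed-forward block, across the pre/post-norm, MoE,
// distributed and incremental-inference variants named below) into a single
// ops::EncoderLayer node.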
class EncoderLayerFusion : public MultiplePatternProcessPass {
 public:
  explicit EncoderLayerFusion(bool embedding_layer = false, const std::string &name = "EncoderLayerFusion",
                              bool multigraph = true)
      : MultiplePatternProcessPass(name, multigraph) {
    embedding_layer_ = embedding_layer;
  }
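  // Usage sketch (hedged: how the pass is driven is an assumption about the
  // surrounding pass-manager API, which this header does not define):
  //   auto fusion = std::make_shared<opt::EncoderLayerFusion>(/*embedding_layer=*/false);
  //   fusion->Run(func_graph);  // hypothetical entry point inherited from the pass base class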

  ~EncoderLayerFusion() override = default;

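  // Callback invoked when one of the patterns below matches: replaces the
  // matched subgraph rooted at the given node with the fused encoder layer,
  // using the equivalence table filled in by the matcher.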
  AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
                     const EquivPtr &) const override;
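  // Returns all supported patterns, keyed by the kPattern* names declared
  // below; the key of the matching pattern is passed back to Process().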
  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;

 protected:
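  // Allocates the shared pattern variables (the mutable VarPtr members below)
  // that the DefinePattern* helpers reference when building patterns.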
  virtual bool Init() const;

 private:
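  // Names identifying each supported encoder-layer variant.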
  const std::string kPatternEncoderLayerPreNormUsePast = "PatternEncoderLayerPreNormUsePast";
  const std::string kPatternEncoderLayerUsePastWithLastNorm = "PatternEncoderLayerPreNormUsePastWithLastNorm";
  const std::string kPatternEncoderLayerPost = "PatternTEncoderLayerPost";
  const std::string kPatternEncoderLayerPre = "PatternTEncoderLayerPre";
  const std::string kPatternEncoderLayerPostNorm = "PatternTEncoderLayerPostNorm";
  const std::string kPatternEncoderLayerPreNorm = "PatternTEncoderLayerPreNorm";
  const std::string kPatternEncoderLayerT5Post = "PatternEncoderLayerT5Post";
  const std::string kPatternEncoderLayerT5Pre = "PatternEncoderLayerT5Pre";
  const std::string kPatternEncoderLayerNormT5Pre = "PatternEncoderLayerNormT5Pre";
  const std::string kPatternQueryLayerUsePast = "PatternQueryLayerUsePast";
  const std::string kPatternSigmaDistributed = "PatternSigmaDistributed";
  const std::string kPatternSigmaDistributedEmbedding = "PatternSigmaDistributedEmbedding";
  const std::string kPatternSigmaMoeDistributed = "PatternSigmaMoeDistributed";
  const std::string kPatternSigmaMoeWithLastLayerNormDistributed = "PatternSigmaMoeWithLastLayerNormDistributed";
  const std::string kPatternSigmaWithLastLayerNormDistributed = "PatternSigmaWithLastLayerNormDistributed";
  const std::string kPatternSigmaQueryLayerDistributed = "PatternSigmaQueryLayerDistributed";
  const std::string kPatternDistributedAlpha = "PatternDistributedAlpha";
  const std::string kPatternDistributedAlphaWithLastLayerNorm = "PatternDistributedAlphaWithLastLayerNorm";
  const std::string kPatternQueryLayerUsePastDistributed = "PatternQueryLayerUsePastDistributed";
  const std::string kPatternSigma = "PatternSigma";
  const std::string kPatternSigmaEmbedding = "PatternSigmaEmbedding";
  const std::string kPatternSigmaQuery = "PatternSigmaQuery";
  const std::string kPatternSigmaMoe = "PatternSigmaMoe";
  const std::string kPatternSigmaMoeWithLastLayerNorm = "PatternSigmaMoeWithLastLayerNorm";
  const std::string kPatternSigmaWithLastLayerNorm = "PatternSigmaWithLastLayerNorm";
  const std::string kPatternSigmaQueryLayerMoe = "PatternSigmaQueryLayerMoe";
  const std::string kPatternSigmaDistributedMB = "PatternSigmaDistributedMB";
  const std::string kPatternSigmaDistributedEmbeddingMB = "PatternSigmaDistributedEmbeddingMB";
  const std::string kPatternSigmaMoeWithLastLayerNormDistributedMB = "PatternSigmaMoeWithLastLayerNormDistributedMB";
  const std::string kPatternSigmaWithLastLayerNormDistributedMB = "PatternSigmaWithLastLayerNormDistributedMB";
  const std::string kPatternSigmaQueryLayerDistributedMB = "PatternSigmaQueryLayerDistributedMB";
  const std::string kPatternSigmaMoeDistributedMB = "PatternSigmaMoeDistributedMB";
  const std::string kPatternSigmaDistributedMBGELU = "PatternSigmaDistributedMBGELU";
  const std::string kPatternSigmaDistributedEmbeddingMBGELU = "PatternSigmaDistributedEmbeddingMBGELU";
  const std::string kPatternSigmaMoeWithLastLayerNormDistributedMBGELU =
    "PatternSigmaMoeWithLastLayerNormDistributedMBGELU";
  const std::string kPatternSigmaWithLastLayerNormDistributedMBGELU = "PatternSigmaWithLastLayerNormDistributedMBGELU";
  const std::string kPatternSigmaQueryLayerDistributedMBGELU = "PatternSigmaQueryLayerDistributedMBGELU";
  const std::string kPatternSigmaMoeDistributedMBGELU = "PatternSigmaMoeDistributedMBGELU";
  const std::string kPatternSigmaDistributedMBFirst = "PatternSigmaDistributedMBFirst";
  const std::string kPatternSigmaFirst = "PatternSigmaFirst";
  const std::string kPatternSigmaQueryLayerDistributedMBMoe = "PatternSigmaQueryLayerDistributedMBMoe";
  const std::string kPatternSigmaQueryLayerDistributedMoe = "PatternSigmaQueryLayerDistributedMoe";
  const std::string kPatternSigmaEmbeddingDistributed = "PatternSigmaEmbeddingDistributed";

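  // Pattern builders: each helper assembles the VectorRef describing one
  // encoder-layer variant (or one of its sub-blocks, such as the FFN, MoE
  // router or layer norm) from the shared pattern variables declared below.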
  VectorRef DefinePatternEncoderLayer(bool post_layernorm, bool layernorm_fusion, bool is_position_bias, bool mask,
                                      bool is_layer_norm) const;
  VectorRef DefinePatternEncoderSigma(bool moe, bool use_past, bool distributed, bool is_layer_norm, bool query_layer,
                                      bool multi_batch, bool first_encoder, bool gelu) const;

  VectorRef DefinePatternEncoderAlpha(bool moe, bool distributed, bool is_layer_norm, bool query_layer,
                                      bool use_past) const;
  VectorRef getTuple(bool post_layernorm, bool layernorm_fusion, bool is_position_bias) const;
  VectorRef DefineLayerNorm(bool is_position_bias, BaseRef input, VarPtr gamma, VarPtr beta, VarPtr eps) const;
  CNodePtr CreateMaskedEncoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                              const AnfNodePtr &node, bool post_layernorm, bool mask) const;
  AnfNodePtr GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv, VarPtr node_name) const;
  bool IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const VarPtr &input_prim) const;
  lite::STATUS GetEps(const EquivPtr &equiv, VarPtr node_name, float *eps) const;
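  // Validates the matched subgraph and extracts the attention head count and
  // head size, the three layer-norm epsilons and the scale factor.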
  lite::STATUS CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num, int *head_size,
                            float *eps1, float *eps2, float *eps3, float *scale) const;
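  // Builds the fused ops::EncoderLayer primitive, filling in its attributes
  // from the matched pattern and the values passed in.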
  std::shared_ptr<ops::EncoderLayer> CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                int64_t ffn_hidden_size, int64_t expert_num, int64_t expert_offset,
                                                float capacity_factor) const;
  VectorRef DefinePatternInitReset(VectorRef input, bool is_value_reset = false, bool is_key_reset = false) const;
  VectorRef DefinePatternMultiBatch(VectorRef input) const;
  BaseRef DefineBatchValidLength(const BaseRef &input) const;
  VectorRef DefinePatternMoERouter(VectorRef input_layernorm) const;
  VectorRef DefinePatternMoE(VectorRef input_layernorm, bool multi_batch, bool gelu) const;
  VectorRef DefinePatternSigmaFfn(BaseRef input, bool gelu, bool distributed) const;
  VectorRef DefinePatternMoETopKRouter(VectorRef input) const;
  VectorRef DefinePatternMoEFfn(VectorRef input_reshape, bool gelu) const;
  VectorRef DefineDependKV(VectorRef input_layernorm, VectorRef depend_v_input, bool moe) const;
  VectorRef DefineFfn(VectorRef input) const;
  VectorRef DefineFirstEncoder(bool distributed) const;
  lite::STATUS InitAttributes(AnfNodePtr k_past, AnfNodePtr begin_expert_ids, AnfNodePtr weight_m,
                              AnfNodePtr expert_capacity_node, int *ffn_hidden_size, int *expert_num,
                              int *expert_offset, float *capacity_factor) const;
  void InitParams(bool post_layernorm, bool layernorm_fusion, bool is_position_bias, bool mask, bool is_layer_norm,
                  bool use_past, bool query_layer, bool sigma, bool distributed, bool moe) const;
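  // Predicates over the matched pattern name, used to decode which
  // encoder-layer variant was found.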
  bool IsUsePast(const std::string &pattern_name) const;
  bool IsUsePastMB(const std::string &pattern_name) const;
  bool IsUsePastAlpha(const std::string &pattern_name) const;
  bool IsLastLayerNorm(const std::string &pattern_name) const;
  bool IsLayerNormFusion(const std::string &pattern_name) const;
  bool IsMoe(const std::string &pattern_name) const;
  bool IsFastGelu(const std::string &pattern_name) const;
  bool IsFastGeluDistributed(const std::string &pattern_name) const;
  bool IsQueryLayer(const std::string &pattern_name) const;

 protected:
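  // Pattern variables captured during matching; mutable so the const
  // pattern-building methods above can (re)initialize them.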
  mutable VarPtr input_{nullptr};
  mutable VarPtr expert_ids_input_{nullptr};
  mutable VarPtr expert_ids_{nullptr};
  mutable VarPtr expert_capacity_{nullptr};
  mutable VarPtr begin_expert_ids_{nullptr};
  mutable VarPtr position_bias_{nullptr};
  mutable VarPtr beta1_{nullptr};
  mutable VarPtr gamma1_{nullptr};
  mutable VarPtr beta2_{nullptr};
  mutable VarPtr gamma2_{nullptr};
  mutable VarPtr beta3_{nullptr};
  mutable VarPtr gamma3_{nullptr};
  mutable VarPtr weight_attn_qkv_{nullptr};
  mutable VarPtr weight_attn_q_{nullptr};
  mutable VarPtr weight_attn_o_{nullptr};
  mutable VarPtr weight_m_{nullptr};
  mutable VarPtr weight_p_{nullptr};
  mutable VarPtr bias_attn_qkv_{nullptr};
  mutable VarPtr bias_attn_o_{nullptr};
  mutable VarPtr bias_m_{nullptr};
  mutable VarPtr bias_p_{nullptr};
  mutable VarPtr mask_{nullptr};
  mutable VarPtr is_attention_{nullptr};
  mutable VarPtr is_layernorm1_{nullptr};
  mutable VarPtr is_layernorm2_{nullptr};
  mutable VarPtr is_layernorm3_{nullptr};
  mutable ActType act_type_{ActType::ActType_No};
  mutable VarPtr is_act_{nullptr};
  mutable VarPtr eps1_{nullptr};
  mutable VarPtr eps2_{nullptr};
  mutable VarPtr eps3_{nullptr};
  mutable VarPtr init_reset_{nullptr};
  mutable VarPtr k_past_{nullptr};
  mutable VarPtr v_past_{nullptr};
  mutable VarPtr input_q_{nullptr};
  mutable VarPtr batch_valid_length_{nullptr};
  mutable VarPtr embedding_table_{nullptr};
  mutable VarPtr weight_router_{nullptr};

  mutable VarPtr position_ids_{nullptr};
  mutable VarPtr embedding_table_input_{nullptr};
  mutable VarPtr current_index_{nullptr};
  mutable VarPtr embedding_table_pos_{nullptr};

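  // Per-pattern flags recorded by InitParams() to describe the matched
  // variant.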
  mutable bool is_position_bias_{false};
  mutable bool is_post_layernorm_{false};
  mutable bool is_layernorm_fusion_{false};
  mutable bool is_layernorm_{false};
  mutable bool is_use_past_{false};
  mutable bool is_query_layer_{false};
  mutable bool is_sigma_{false};
  mutable bool is_moe_{false};
  mutable bool is_distributed_{false};
  mutable bool is_fast_gelu_{false};
  mutable bool is_embedding_layer_{false};

  mutable bool embedding_layer_{false};
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ENCODER_LAYER_FUSION_H_