/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_DECODER_LAYER_FUSION_H_
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_DECODER_LAYER_FUSION_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tools/optimizer/common/multiple_pattern_process_pass.h"
#include "include/common/utils/utils.h"
#include "include/errorcode.h"
#include "ops/decoder_layer.h"
#include "ops/fusion/layer_norm_fusion.h"
#include "ops/fusion/activation.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"

namespace mindspore {
namespace opt {
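// Fusion pass that matches an entire transformer decoder layer (self-attention,
// cross-attention, layer norms, and the feed-forward block) and replaces the
// matched subgraph with a single ops::DecoderLayer node. Both pre- and
// post-layernorm layouts are covered, as are T5-style variants that carry a
// position bias instead of an attention mask (see the pattern names below).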
class DecoderLayerFusion : public MultiplePatternProcessPass {
 public:
  explicit DecoderLayerFusion(const std::string &name = "DecoderLayerFusion", bool multigraph = true)
      : MultiplePatternProcessPass(name, multigraph) {}

  ~DecoderLayerFusion() override = default;

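  // Interface inherited from MultiplePatternProcessPass: DefinePatterns()
  // registers each candidate subgraph under a pattern name, and Process() is
  // called with the matched equivalence table to emit the fused node.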
  AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &, const AnfNodePtr &,
                     const EquivPtr &) const override;
  std::unordered_map<std::string, VectorRef> DefinePatterns() const override;

 protected:
  virtual bool Init() const;

 private:
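  // Pattern construction: each helper returns a VectorRef describing one
  // decoder-layer subgraph variant (pre/post layernorm, fused or unfused
  // layer norm, optional attention mask, optional T5 position bias).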
  VectorRef DefinePatternDecoderLayer(bool post_layernorm, bool layernorm_fusion, bool is_position_bias, bool mask,
                                      bool is_layer_norm) const;
  VectorRef getTuple(bool post_layernorm, bool layernorm_fusion, bool is_position_bias) const;
  VectorRef DefineLayerNorm(VectorRef input, VarPtr gamma, VarPtr beta, VarPtr eps) const;
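  // Fused-node construction: CreatePrim() builds the ops::DecoderLayer
  // primitive from the extracted attributes, and
  // CreateMaskedDecoderLayerFusionNode() splices it into the graph in place
  // of the matched subgraph.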
  CNodePtr CreateMaskedDecoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                              const AnfNodePtr &node, bool post_layernorm, bool mask) const;
  std::shared_ptr<ops::DecoderLayer> CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                bool post_layernorm, int64_t ffn_hidden_size) const;
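  // Match validation and attribute extraction: head count and size, the four
  // layer-norm epsilons, position-bias flags, and attention scales are read
  // from the equivalence table; IsActGELU() tests whether the matched FFN
  // activation is GELU.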
  lite::STATUS CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num, int *head_size,
                            float *eps1, float *eps2, float *eps3, float *eps4, bool *is_position_bias1,
                            bool *is_position_bias2, float *scale1, float *scale2) const;
  AnfNodePtr GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv, VarPtr node_name) const;
  bool IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const;
  lite::STATUS GetEps(const EquivPtr &equiv, VarPtr node_name, float *eps) const;
  VectorRef DefineDecoderLayerNorm(VectorRef input, VarPtr gamma, VarPtr eps) const;

 protected:
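  // Names under which the candidate patterns are registered; the Pre/Post
  // suffix tracks layer-norm placement, and the Norm/T5 infixes mark the
  // layer-norm and position-bias (T5-style) variants.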
  const std::string kPatternDecoderLayerPre = "PatternDecoderLayerPre";
  const std::string kPatternDecoderLayerPost = "PatternDecoderLayerPost";
  const std::string kPatternDecoderLayerNormPre = "PatternDecoderLayerNormPre";
  const std::string kPatternDecoderLayerNormPost = "PatternDecoderLayerNormPost";
  const std::string kPatternDecoderLayerNormT5Pre = "PatternDecoderLayerNormT5Pre";
  const std::string kPatternDecoderT5Pre = "PatternDecoderT5Pre";
  const std::string kPatternDecoderT5Post = "PatternDecoderT5Post";
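  // Capture variables bound during pattern matching (inputs, weights, biases,
  // layer-norm parameters, masks, and sub-pattern markers). They are mutable
  // because the patterns are (re)built from const member functions.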
  mutable VarPtr hidden_stats_{nullptr};
  mutable VarPtr encoder_output_{nullptr};
  mutable VarPtr position_bias_{nullptr};
  mutable VarPtr beta1_{nullptr};
  mutable VarPtr gamma1_{nullptr};
  mutable VarPtr beta2_{nullptr};
  mutable VarPtr gamma2_{nullptr};
  mutable VarPtr gamma3_{nullptr};
  mutable VarPtr gamma4_{nullptr};
  mutable VarPtr beta3_{nullptr};
  mutable VarPtr beta4_{nullptr};
  mutable VarPtr weight_attn_qkv_{nullptr};
  mutable VarPtr weight_attn_qkv_cross_{nullptr};
  mutable VarPtr weight_attn_o_{nullptr};
  mutable VarPtr weight_m_{nullptr};
  mutable VarPtr weight_p_{nullptr};
  mutable VarPtr bias_attn_qkv_{nullptr};
  mutable VarPtr bias_attn_o_{nullptr};
  mutable VarPtr bias_attn_cross_qkv_{nullptr};
  mutable VarPtr bias_attn_cross_o_{nullptr};
  mutable VarPtr bias_m_{nullptr};
  mutable VarPtr bias_p_{nullptr};
  mutable VarPtr mask_{nullptr};
  mutable VarPtr is_attention_{nullptr};
  mutable VarPtr is_attention_cross_{nullptr};
  mutable VarPtr weight_attn_q_{nullptr};
  mutable VarPtr weight_attn_kv_{nullptr};
  mutable VarPtr weight_attn_cross_o_{nullptr};
  mutable VarPtr position_bias_cross_{nullptr};
  mutable VarPtr cross_mask_{nullptr};
  mutable VarPtr reshape_k_{nullptr};
  mutable VarPtr reshape_v_{nullptr};
  mutable VarPtr is_layernorm1_{nullptr};
  mutable VarPtr is_layernorm2_{nullptr};
  mutable VarPtr is_layernorm3_{nullptr};
  mutable VarPtr is_act_{nullptr};
  mutable VarPtr eps1_{nullptr};
  mutable VarPtr eps2_{nullptr};
  mutable VarPtr eps3_{nullptr};
  mutable VarPtr eps4_{nullptr};
  mutable bool is_position_bias_{false};
  mutable bool is_layernorm_fusion_{false};
  mutable bool is_layernorm_{false};
  mutable ActType act_type_{ActType::ActType_No};
  mutable bool layer_norm_{false};
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_DECODER_LAYER_FUSION_H_
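// Usage sketch (hypothetical, for illustration only; the real registration
// happens inside MindSpore Lite's fusion pipeline, and the pass-manager
// identifiers below are assumptions, not part of this header):
//
//   auto pm = std::make_shared<opt::PassManager>("fusion pass manager", false);
//   pm->AddPass(std::make_shared<opt::DecoderLayerFusion>());
//   // Running the manager walks each FuncGraph; on a pattern hit the base
//   // class calls Process(pattern_name, graph, node, equiv), which validates
//   // the match via CheckPattern() and splices in the fused DecoderLayer node.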