/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/decoder_layer_fusion.h"
#include <functional>
#include <utility>
#include <vector>
#include <algorithm>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h"
#include "ops/op_utils.h"

namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
}  // namespace

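// Init() creates the pattern variables (decoder/encoder inputs, attention and
// FFN weights and biases, layer-norm gammas/betas/eps, masks, position biases
// and the operator predicates) that the match patterns below are built from.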
bool DecoderLayerFusion::Init() const {
  hidden_stats_ = std::make_shared<Var>("input");
  MS_CHECK_TRUE_RET(hidden_stats_ != nullptr, false);
  encoder_output_ = std::make_shared<Var>("input");
  MS_CHECK_TRUE_RET(encoder_output_ != nullptr, false);
  beta1_ = std::make_shared<Var>("beta1");
  MS_CHECK_TRUE_RET(beta1_ != nullptr, false);
  gamma1_ = std::make_shared<Var>("gamma1");
  MS_CHECK_TRUE_RET(gamma1_ != nullptr, false);
  beta2_ = std::make_shared<Var>("beta2");
  MS_CHECK_TRUE_RET(beta2_ != nullptr, false);
  gamma2_ = std::make_shared<Var>("gamma2");
  MS_CHECK_TRUE_RET(gamma2_ != nullptr, false);
  beta3_ = std::make_shared<Var>("beta3");
  MS_CHECK_TRUE_RET(beta3_ != nullptr, false);
  gamma3_ = std::make_shared<Var>("gamma3");
  MS_CHECK_TRUE_RET(gamma3_ != nullptr, false);
  gamma4_ = std::make_shared<Var>("gamma4");
  MS_CHECK_TRUE_RET(gamma4_ != nullptr, false);
  beta4_ = std::make_shared<Var>("beta4");
  MS_CHECK_TRUE_RET(beta4_ != nullptr, false);
  weight_attn_qkv_ = std::make_shared<Var>("weight_attn_qkv");
  MS_CHECK_TRUE_RET(weight_attn_qkv_ != nullptr, false);
  weight_attn_q_ = std::make_shared<Var>("weight_attn_q_");
  MS_CHECK_TRUE_RET(weight_attn_q_ != nullptr, false);
  weight_attn_kv_ = std::make_shared<Var>("weight_attn_kv_");
  MS_CHECK_TRUE_RET(weight_attn_kv_ != nullptr, false);
  weight_attn_o_ = std::make_shared<CondVar>(IsParamNode, "weight_attn_o");
  MS_CHECK_TRUE_RET(weight_attn_o_ != nullptr, false);
  weight_attn_cross_o_ = std::make_shared<CondVar>(IsParamNode, "weight_attn_cross_o_");
  MS_CHECK_TRUE_RET(weight_attn_cross_o_ != nullptr, false);
  weight_m_ = std::make_shared<CondVar>(IsParamNode, "weight_m");
  MS_CHECK_TRUE_RET(weight_m_ != nullptr, false);
  weight_p_ = std::make_shared<CondVar>(IsParamNode, "weight_p");
  MS_CHECK_TRUE_RET(weight_p_ != nullptr, false);
  bias_attn_qkv_ = std::make_shared<Var>("bias_attn_qkv");
  MS_CHECK_TRUE_RET(bias_attn_qkv_ != nullptr, false);
  bias_attn_o_ = std::make_shared<CondVar>(IsParamNode, "bias_attn_o");
  MS_CHECK_TRUE_RET(bias_attn_o_ != nullptr, false);
  bias_attn_cross_qkv_ = std::make_shared<Var>("bias_attn_cross_qkv_");
  MS_CHECK_TRUE_RET(bias_attn_cross_qkv_ != nullptr, false);
  bias_attn_cross_o_ = std::make_shared<CondVar>(IsParamNode, "bias_attn_cross_o_");
  MS_CHECK_TRUE_RET(bias_attn_cross_o_ != nullptr, false);
  bias_m_ = std::make_shared<CondVar>(IsParamNode, "bias_m");
  MS_CHECK_TRUE_RET(bias_m_ != nullptr, false);
  bias_p_ = std::make_shared<CondVar>(IsParamNode, "bias_p");
  MS_CHECK_TRUE_RET(bias_p_ != nullptr, false);
  mask_ = std::make_shared<Var>("mask");
  MS_CHECK_TRUE_RET(mask_ != nullptr, false);
  cross_mask_ = std::make_shared<Var>("cross_mask_");
  MS_CHECK_TRUE_RET(cross_mask_ != nullptr, false);
  is_attention_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAttention), "is_attention");
  MS_CHECK_TRUE_RET(is_attention_ != nullptr, false);
  is_attention_cross_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAttention), "is_attention_cross");
  MS_CHECK_TRUE_RET(is_attention_cross_ != nullptr, false);
  is_layernorm1_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm1");
  MS_CHECK_TRUE_RET(is_layernorm1_ != nullptr, false);
  is_layernorm2_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm2");
  MS_CHECK_TRUE_RET(is_layernorm2_ != nullptr, false);
  is_layernorm3_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm3");
  MS_CHECK_TRUE_RET(is_layernorm3_ != nullptr, false);
  position_bias_ = std::make_shared<Var>("position_bias");
  MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
  position_bias_cross_ = std::make_shared<Var>("position_bias_cross_");
  MS_CHECK_TRUE_RET(position_bias_cross_ != nullptr, false);
  is_act_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "activation");
  MS_CHECK_TRUE_RET(is_act_ != nullptr, false);
  eps1_ = std::make_shared<Var>("eps1_");
  MS_CHECK_TRUE_RET(eps1_ != nullptr, false);
  eps2_ = std::make_shared<Var>("eps2_");
  MS_CHECK_TRUE_RET(eps2_ != nullptr, false);
  eps3_ = std::make_shared<Var>("eps3_");
  MS_CHECK_TRUE_RET(eps3_ != nullptr, false);
  eps4_ = std::make_shared<Var>("eps4_");
  MS_CHECK_TRUE_RET(eps4_ != nullptr, false);
  return true;
}

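// Builds the pattern for the reshaped decoder input followed by its first layer
// norm: a decomposed norm (DefineLayerNorm) when layernorm_fusion is false,
// otherwise a LayerNormFusion node read through TupleGetItem.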
VectorRef DecoderLayerFusion::getTuple(bool post_layernorm, bool layernorm_fusion = false,
                                       bool is_position_bias = false) const {
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1-reshape");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, hidden_stats_, var1});
  VectorRef layer_norm, tuple;
  if (!layernorm_fusion) {
    return DefineLayerNorm(reshape1, gamma1_, beta1_, eps1_);
  }
  layer_norm = VectorRef({is_layernorm1_, reshape1, gamma1_, beta1_});
  auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple = std::make_shared<Var>("var_tuple");
  tuple = VectorRef({is_tuple, layer_norm, var_tuple});
  return tuple;
}

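// Matches a decomposed layer norm without mean subtraction (the T5-style RMS
// variant), i.e. Mul(RealDiv(input, Sqrt(Add(Reduce(Square(input)), eps))), gamma).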
VectorRef DecoderLayerFusion::DefineLayerNorm(VectorRef input, VarPtr gamma, VarPtr beta, VarPtr eps) const {
  auto is_sqr = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSquare), "sqr2");
  MS_CHECK_TRUE_RET(is_sqr != nullptr, {});
  auto sqr = VectorRef({is_sqr, input});
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce");
  MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
  auto reduce = VectorRef({is_reduce, sqr, var1});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is-add");
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto add = VectorRef({is_add, reduce, eps});
  auto is_sqrt = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSqrt), "sqrt");
  MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
  auto sqrt = VectorRef({is_sqrt, add});
  auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "real-div");
  MS_CHECK_TRUE_RET(is_div != nullptr, {});
  auto real_div = VectorRef({is_div, input, sqrt});
  auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul");
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  auto mul = VectorRef({is_mul, real_div, gamma});
  return mul;
}

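// Assembles the full decoder-layer pattern: self-attention, cross-attention over
// the reshaped encoder output, and the feed-forward MatMuls, each followed by a
// residual add and a fused or decomposed layer norm. The flags select the
// pre/post-layernorm variant, fused norms, T5 position bias, attention masks and
// an optional final layer norm.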
VectorRef DecoderLayerFusion::DefinePatternDecoderLayer(bool post_layernorm = true, bool layernorm_fusion = false,
                                                        bool is_position_bias = false, bool mask = true,
                                                        bool is_layer_norm = false) const {
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1-reshape");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, hidden_stats_, var1});
  VectorRef inputs, input_cross, tuple2, tuple3, matmul2, tuple4, tuple5;
  if (is_position_bias) {
    inputs = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_, weight_attn_o_,
                        position_bias_});
  } else {
    inputs = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_, weight_attn_o_,
                        bias_attn_qkv_, bias_attn_o_});
  }
  if (mask) inputs.push_back(mask_);
  auto attention = VectorRef(inputs);
  if (is_position_bias) {
    tuple4 = attention;
  } else {
    auto is_tuple4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item4");
    auto var_tuple4 = std::make_shared<Var>("var_tuple4");
    tuple4 = VectorRef({is_tuple4, attention, var_tuple4});
  }
  auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add2");
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  auto add2 = (post_layernorm)
                ? VectorRef({is_add2, getTuple(post_layernorm, layernorm_fusion, is_position_bias), tuple4})
                : VectorRef({is_add2, reshape1, tuple4});
  if (layernorm_fusion) {
    auto layer_norm2 = VectorRef({is_layernorm2_, add2, gamma2_, beta2_});
    auto is_tuple2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item2");
    auto var_tuple2 = std::make_shared<Var>("var_tuple2");
    tuple2 = VectorRef({is_tuple2, layer_norm2, var_tuple2});
  } else {
    tuple2 = DefineLayerNorm(add2, gamma2_, beta2_, eps2_);
  }
  auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder2");
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto reshape2 = VectorRef({is_reshape2, encoder_output_, var2});
  if (is_position_bias) {
    input_cross = VectorRef({is_attention_cross_, tuple2, reshape2, reshape2, weight_attn_q_, weight_attn_kv_,
                             weight_attn_cross_o_, position_bias_cross_});
  } else {
    input_cross = VectorRef({is_attention_cross_, tuple2, reshape2, reshape2, weight_attn_q_, weight_attn_kv_,
                             weight_attn_cross_o_, bias_attn_cross_qkv_, bias_attn_cross_o_});
  }
  if (mask) input_cross.push_back(cross_mask_);
  auto attention_cross = VectorRef(input_cross);
  if (is_position_bias) {
    tuple5 = attention_cross;
  } else {
    auto is_tuple5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item5");
    auto var_tuple5 = std::make_shared<Var>("var_tuple5");
    tuple5 = VectorRef({is_tuple5, attention_cross, var_tuple5});
  }
  auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
  MS_CHECK_TRUE_RET(is_add3 != nullptr, {});
  auto add3 = (post_layernorm) ? VectorRef({is_add3, tuple2, tuple5}) : VectorRef({is_add3, add2, tuple5});
  if (layernorm_fusion) {
    auto layer_norm3 = VectorRef({is_layernorm3_, add3, gamma3_, beta3_});
    auto is_tuple3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item3");
    auto var_tuple3 = std::make_shared<Var>("var_tuple3");
    tuple3 = VectorRef({is_tuple3, layer_norm3, var_tuple3});
  } else {
    tuple3 = DefineLayerNorm(add3, gamma3_, beta3_, eps3_);
  }
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  if (!is_position_bias) {
    auto matmul1 = VectorRef({is_matmul1, tuple3, weight_m_, bias_m_});
    auto act = VectorRef({is_act_, matmul1});
    matmul2 = VectorRef({is_matmul2, act, weight_p_, bias_p_});
  } else {
    auto matmul1 = VectorRef({is_matmul1, tuple3, weight_m_});
    matmul2 = VectorRef({is_matmul2, matmul1, weight_p_});
  }
  auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder3");
  MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
  auto var3 = std::make_shared<Var>("var3");
  MS_CHECK_TRUE_RET(var3 != nullptr, {});
  auto reshape3 = VectorRef({is_reshape3, matmul2, var3});
  auto is_reshape4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder4");
  MS_CHECK_TRUE_RET(is_reshape4 != nullptr, {});
  auto var4 = std::make_shared<Var>("var4");
  MS_CHECK_TRUE_RET(var4 != nullptr, {});
  auto reshape4 = (post_layernorm) ? VectorRef({is_reshape4, tuple3, var4}) : VectorRef({is_reshape4, add3, var4});
  auto is_add4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add4");
  MS_CHECK_TRUE_RET(is_add4 != nullptr, {});
  auto add4 = VectorRef({is_add4, reshape4, reshape3});
  return (is_layer_norm) ? DefineLayerNorm(add4, gamma4_, beta4_, eps4_) : add4;
}

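// Registers the pattern variants this pass tries to match. Argument order:
// DefinePatternDecoderLayer(post_layernorm, layernorm_fusion, is_position_bias, mask, is_layer_norm).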
std::unordered_map<std::string, VectorRef> DecoderLayerFusion::DefinePatterns() const {
  std::unordered_map<std::string, VectorRef> patterns;
  if (!Init()) {
    MS_LOG(ERROR) << "initialize members failed.";
    return patterns;
  }
  patterns[kPatternDecoderLayerNormT5Pre] = DefinePatternDecoderLayer(false, false, true, true, true);
  patterns[kPatternDecoderLayerPre] = DefinePatternDecoderLayer(false, true, false, true);
  patterns[kPatternDecoderLayerPost] = DefinePatternDecoderLayer(true, true, false, true);
  patterns[kPatternDecoderLayerNormPre] = DefinePatternDecoderLayer(false, false, false, true);
  patterns[kPatternDecoderLayerNormPost] = DefinePatternDecoderLayer(true, false, false, true);
  patterns[kPatternDecoderT5Pre] = DefinePatternDecoderLayer(false, false, true, true);
  patterns[kPatternDecoderT5Post] = DefinePatternDecoderLayer(true, false, true, true);
  return patterns;
}

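// Entry point for every pattern match: records which variant matched (position
// bias, final layer norm, fused layer norm, post-layernorm) in member flags and
// replaces the matched subgraph with a single DecoderLayer node.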
AnfNodePtr DecoderLayerFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
                                       const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    return nullptr;
  }
  if (pattern_name == kPatternDecoderT5Pre || pattern_name == kPatternDecoderT5Post ||
      pattern_name == kPatternDecoderLayerNormT5Pre) {
    is_position_bias_ = true;
  }
  is_layernorm_ = false;
  if (pattern_name == kPatternDecoderLayerNormT5Pre) {
    is_layernorm_ = true;
  }
  if (pattern_name == kPatternDecoderLayerPre || pattern_name == kPatternDecoderLayerPost) {
    is_layernorm_fusion_ = true;
  }
  bool mask = true;
  bool post_layernorm = false;
  if (pattern_name == kPatternDecoderLayerPost || pattern_name == kPatternDecoderT5Post ||
      pattern_name == kPatternDecoderLayerNormPost) {
    post_layernorm = true;
  }
  return CreateMaskedDecoderLayerFusionNode(func_graph, equiv, node, post_layernorm, mask);
}

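// Returns true only when the matched feed-forward activation is GELU.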
bool DecoderLayerFusion::IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
  auto act_input = GetAttribute(func_graph, equiv, is_act_);
  MS_ASSERT(act_input != nullptr);
  auto act_primitive = ops::GetOperator<ops::Activation>(act_input);
  MS_CHECK_TRUE_RET(act_primitive != nullptr, false);
  auto act_primitive_c = act_primitive->GetPrim();
  if (act_primitive_c->GetAttr(ops::kActivationType) == nullptr ||
      act_primitive->get_activation_type() != mindspore::GELU) {
    return false;
  }
  return true;
}

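// Resolves a matched pattern variable to a CNode (falling back to its first user
// when the match itself is not a CNode) and returns that node's primitive input.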
AnfNodePtr DecoderLayerFusion::GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                            VarPtr node_name) const {
  if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
    MS_LOG(ERROR) << node_name << " is not an AnfNodePtr";
    return nullptr;
  }
  AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
  MS_ASSERT(node != nullptr);
  if (node == nullptr || !utils::isa<CNodePtr>(node)) {
    auto manager = func_graph->manager();
    if (manager == nullptr) {
      return nullptr;
    }
    auto users = manager->node_users();
    auto it = users.find(node);
    if (it != users.end()) {
      node = it->second.front().first;
    }
    if (node == nullptr || !utils::isa<CNodePtr>(node)) {
      return nullptr;
    }
  }
  auto cnode = utils::cast<CNodePtr>(node);
  MS_ASSERT(cnode != nullptr);
  auto input = cnode->input(0);
  return input;
}

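// Reads the layer-norm epsilon captured by a decomposed-norm pattern from the
// matched constant tensor node.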
STATUS DecoderLayerFusion::GetEps(const EquivPtr &equiv, VarPtr node_name, float *eps) const {
  if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
    MS_LOG(ERROR) << node_name << " is not an AnfNodePtr";
    return RET_ERROR;
  }
  AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
  MS_ASSERT(node != nullptr);
  if (utils::isa<ValueNodePtr>(node)) {
    auto value_ptr_node = utils::cast<ValueNodePtr>(node);
    auto value_node = utils::cast<ValuePtr>(value_ptr_node->value());
    if (value_node->isa<tensor::Tensor>()) {
      auto tensor = value_node->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      *eps = *reinterpret_cast<float *>(tensor->data().data());
      return RET_OK;
    }
  }
  return RET_ERROR;
}

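// Collects head number/size, scales and position-bias flags from the matched
// Attention primitives, the epsilons from either the LayerNormFusion primitives
// or the decomposed-norm constants, and verifies the FFN activation (GELU is
// required when no position bias is present, otherwise ReLU is assumed).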
STATUS DecoderLayerFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num,
                                        int *head_size, float *eps1, float *eps2, float *eps3, float *eps4,
                                        bool *is_position_bias1, bool *is_position_bias2, float *scale1,
                                        float *scale2) const {
  auto attn_input = GetAttribute(func_graph, equiv, is_attention_);
  MS_ASSERT(attn_input != nullptr);
  auto attn_prim = ops::GetOperator<ops::Attention>(attn_input);
  if (attn_prim->GetAttr(ops::kNumHeads) != nullptr) *head_num = attn_prim->get_head_num();
  if (attn_prim->GetAttr(ops::kSizePerHead) != nullptr) *head_size = attn_prim->get_head_size();
  if (attn_prim->GetAttr(ops::kPositionBias1) != nullptr) *is_position_bias1 = attn_prim->get_position_bias();
  if (attn_prim->GetAttr(ops::kScale) != nullptr) *scale1 = attn_prim->get_scale();
  auto attn_cross_input = GetAttribute(func_graph, equiv, is_attention_cross_);
  MS_ASSERT(attn_cross_input != nullptr);
  auto attn_cross_prim = ops::GetOperator<ops::Attention>(attn_cross_input);
  if (attn_cross_prim->GetAttr(ops::kPositionBias1) != nullptr)
    *is_position_bias2 = attn_cross_prim->get_position_bias();
  if (attn_cross_prim->GetAttr(ops::kScale) != nullptr) *scale2 = attn_cross_prim->get_scale();
  if (is_layernorm_fusion_) {
    auto layrn1_input = GetAttribute(func_graph, equiv, is_layernorm1_);
    auto layrn1_prim = ops::GetOperator<ops::LayerNormFusion>(layrn1_input);
    if (layrn1_prim->GetAttr(ops::kEpsilon) != nullptr) *eps1 = layrn1_prim->get_epsilon();
    auto layrn2_input = GetAttribute(func_graph, equiv, is_layernorm2_);
    auto layrn2_prim = ops::GetOperator<ops::LayerNormFusion>(layrn2_input);
    if (layrn2_prim->GetAttr(ops::kEpsilon) != nullptr) *eps2 = layrn2_prim->get_epsilon();
    auto layrn3_input = GetAttribute(func_graph, equiv, is_layernorm3_);
    auto layrn3_prim = ops::GetOperator<ops::LayerNormFusion>(layrn3_input);
    if (layrn3_prim->GetAttr(ops::kEpsilon) != nullptr) *eps3 = layrn3_prim->get_epsilon();
  } else {
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps1_, eps1) == RET_OK, RET_ERROR, "not found eps1");
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps2_, eps2) == RET_OK, RET_ERROR, "not found eps2");
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps3_, eps3) == RET_OK, RET_ERROR, "not found eps3");
    if (is_layernorm_) {
      MS_CHECK_TRUE_MSG(GetEps(equiv, eps4_, eps4) == RET_OK, RET_ERROR, "not found eps4");
    }
  }
  if (!is_position_bias_) {
    if (!IsActGELU(func_graph, equiv)) return RET_ERROR;
    act_type_ = ActType::ActType_Gelu;
  } else {
    act_type_ = ActType::ActType_Relu;
  }
  return RET_OK;
}

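// Builds the ops::DecoderLayer primitive from the attributes gathered by
// CheckPattern.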
std::shared_ptr<ops::DecoderLayer> DecoderLayerFusion::CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                                  bool post_layernorm, int64_t ffn_hidden_size) const {
  auto decoder_layer_prim = std::make_shared<ops::DecoderLayer>();
  if (decoder_layer_prim == nullptr) {
    MS_LOG(ERROR) << "Build decoder layer primitive failed.";
    return nullptr;
  }
  int head_num = 0;
  int head_size = 0;
  float eps1 = 1e-6f;
  float eps2 = 1e-6f;
  float eps3 = 1e-6f;
  float eps4 = 1e-6f;
  bool is_position_bias1 = false;
  bool is_position_bias2 = false;
  float scale1 = 1.0f;
  float scale2 = 1.0f;
  if (CheckPattern(func_graph, equiv, &head_num, &head_size, &eps1, &eps2, &eps3, &eps4, &is_position_bias1,
                   &is_position_bias2, &scale1, &scale2) != RET_OK) {
    return nullptr;
  }
  decoder_layer_prim->Init(head_num, head_size, eps1, eps2, eps3, eps4, ffn_hidden_size, is_position_bias1,
                           is_position_bias2, post_layernorm, scale1, scale2, act_type_, is_layernorm_);
  return decoder_layer_prim;
}

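// Gathers the matched weights, biases, norm parameters, masks and position
// biases, derives ffn_hidden_size from the first FFN weight's shape, and creates
// the fused DecoderLayer CNode that replaces the original decoder subgraph.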
CNodePtr DecoderLayerFusion::CreateMaskedDecoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                                const AnfNodePtr &node, bool post_layernorm = true,
                                                                bool mask = true) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(node != nullptr);
  auto input = utils::cast<AnfNodePtr>((*equiv)[hidden_stats_]);
  MS_ASSERT(input != nullptr);
  auto encoder_output = utils::cast<AnfNodePtr>((*equiv)[encoder_output_]);
  MS_ASSERT(encoder_output != nullptr);
  AnfNodePtr position_bias, input_mask, bias_attn_o, bias_attn_qkv, beta1, beta2, bias_m, bias_p, beta3,
    bias_attn_cross_qkv, bias_attn_cross_o, position_bias_cross, gamma4, beta4;
  auto weight_qkv = utils::cast<AnfNodePtr>((*equiv)[weight_attn_qkv_]);
  auto weight_attn_o = utils::cast<AnfNodePtr>((*equiv)[weight_attn_o_]);
  auto weight_attn_q = utils::cast<AnfNodePtr>((*equiv)[weight_attn_q_]);
  auto weight_attn_kv = utils::cast<AnfNodePtr>((*equiv)[weight_attn_kv_]);
  auto weight_attn_cross_o = utils::cast<AnfNodePtr>((*equiv)[weight_attn_cross_o_]);
  auto weight_m = utils::cast<AnfNodePtr>((*equiv)[weight_m_]);
  auto weight_p = utils::cast<AnfNodePtr>((*equiv)[weight_p_]);
  if (is_position_bias_) {
    position_bias = utils::cast<AnfNodePtr>((*equiv)[position_bias_]);
    position_bias_cross = utils::cast<AnfNodePtr>((*equiv)[position_bias_cross_]);
  } else {
    bias_attn_o = utils::cast<AnfNodePtr>((*equiv)[bias_attn_o_]);
    bias_attn_qkv = utils::cast<AnfNodePtr>((*equiv)[bias_attn_qkv_]);
    bias_attn_cross_qkv = utils::cast<AnfNodePtr>((*equiv)[bias_attn_cross_qkv_]);
    bias_attn_cross_o = utils::cast<AnfNodePtr>((*equiv)[bias_attn_cross_o_]);
    bias_m = utils::cast<AnfNodePtr>((*equiv)[bias_m_]);
    bias_p = utils::cast<AnfNodePtr>((*equiv)[bias_p_]);
    beta1 = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
    beta2 = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
    beta3 = utils::cast<AnfNodePtr>((*equiv)[beta3_]);
    if (is_layernorm_) beta4 = utils::cast<AnfNodePtr>((*equiv)[beta4_]);
  }
  auto gamma1 = utils::cast<AnfNodePtr>((*equiv)[gamma1_]);
  auto gamma2 = utils::cast<AnfNodePtr>((*equiv)[gamma2_]);
  auto gamma3 = utils::cast<AnfNodePtr>((*equiv)[gamma3_]);
  if (is_layernorm_) gamma4 = utils::cast<AnfNodePtr>((*equiv)[gamma4_]);

  input_mask = mask ? utils::cast<AnfNodePtr>((*equiv)[mask_]) : nullptr;
  auto cross_mask = utils::cast<AnfNodePtr>((*equiv)[cross_mask_]);
  auto base_shape_ptr = weight_m->Shape();
  MS_EXCEPTION_IF_NULL(base_shape_ptr);
  auto input_shape_ptr = base_shape_ptr->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(input_shape_ptr);
  auto input_shape = input_shape_ptr->shape();
  MS_ASSERT(input_shape.size() > 1);
  int64_t ffn_hidden_size = input_shape[1];
  auto decoder_layer_prim = CreatePrim(func_graph, equiv, post_layernorm, ffn_hidden_size);
  MS_CHECK_TRUE_RET(decoder_layer_prim != nullptr, nullptr);
  auto decoder_layer_prim_c = decoder_layer_prim->GetPrim();
  MS_CHECK_TRUE_RET(decoder_layer_prim_c != nullptr, nullptr);
  auto value_node = NewValueNode(decoder_layer_prim_c);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> new_node_inputs = {value_node, input, gamma1};
  if (is_position_bias_) {
    new_node_inputs.insert(new_node_inputs.end(), {weight_qkv});
    if (mask) new_node_inputs.push_back(input_mask);
    new_node_inputs.insert(new_node_inputs.end(),
                           {position_bias, weight_attn_o, gamma2, encoder_output, weight_attn_q, weight_attn_kv});
    if (mask) new_node_inputs.push_back(cross_mask);
    new_node_inputs.insert(new_node_inputs.end(),
                           {position_bias_cross, weight_attn_cross_o, gamma3, weight_m, weight_p});
    if (is_layernorm_) new_node_inputs.push_back(gamma4);
  } else {
    new_node_inputs.insert(new_node_inputs.end(), {beta1, weight_qkv, bias_attn_qkv});
    if (mask) new_node_inputs.push_back(input_mask);
    new_node_inputs.insert(new_node_inputs.end(), {weight_attn_o, bias_attn_o, gamma2, beta2, encoder_output,
                                                   weight_attn_q, weight_attn_kv, bias_attn_cross_qkv});
    if (mask) new_node_inputs.push_back(cross_mask);
    new_node_inputs.insert(new_node_inputs.end(),
                           {weight_attn_cross_o, bias_attn_cross_o, gamma3, beta3, weight_m, bias_m, weight_p, bias_p});
    if (is_layernorm_) new_node_inputs.insert(new_node_inputs.end(), {gamma4, beta4});
  }
  auto new_node = func_graph->NewCNode(new_node_inputs);
  MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
  auto old_node = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(old_node->abstract() != nullptr, nullptr);
  new_node->set_abstract(old_node->abstract()->Clone());
  new_node->set_fullname_with_scope(node->fullname_with_scope() + "/decoder_layer");
  return new_node;
}
}  // namespace mindspore::opt