/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/encoder_layer_fusion.h"
#include <functional>
#include <utility>
#include <vector>
#include <algorithm>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/comparison_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h"
#include "ops/op_utils.h"

namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
}  // namespace

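// Creates the pattern variables shared by all encoder-layer patterns: plain
// Vars capture graph inputs and weights, CondVars additionally constrain the
// matched node (parameter nodes or a specific primitive type).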
bool EncoderLayerFusion::Init() const {
  input_ = std::make_shared<Var>("input");
  MS_CHECK_TRUE_RET(input_ != nullptr, false);
  expert_ids_ = std::make_shared<Var>("expert_ids");
  MS_CHECK_TRUE_RET(expert_ids_ != nullptr, false);
  expert_capacity_ = std::make_shared<Var>("expert_capacity_");
  MS_CHECK_TRUE_RET(expert_capacity_ != nullptr, false);
  begin_expert_ids_ = std::make_shared<Var>("begin_expert_ids_");
  MS_CHECK_TRUE_RET(begin_expert_ids_ != nullptr, false);
  beta1_ = std::make_shared<Var>("beta1");
  MS_CHECK_TRUE_RET(beta1_ != nullptr, false);
  gamma1_ = std::make_shared<Var>("gamma1");
  MS_CHECK_TRUE_RET(gamma1_ != nullptr, false);
  beta2_ = std::make_shared<Var>("beta2");
  MS_CHECK_TRUE_RET(beta2_ != nullptr, false);
  gamma2_ = std::make_shared<Var>("gamma2");
  MS_CHECK_TRUE_RET(gamma2_ != nullptr, false);
  beta3_ = std::make_shared<Var>("beta3");
  MS_CHECK_TRUE_RET(beta3_ != nullptr, false);
  gamma3_ = std::make_shared<Var>("gamma3");
  MS_CHECK_TRUE_RET(gamma3_ != nullptr, false);
  weight_attn_qkv_ = std::make_shared<Var>("weight_attn_qkv");
  MS_CHECK_TRUE_RET(weight_attn_qkv_ != nullptr, false);
  weight_attn_q_ = std::make_shared<Var>("weight_attn_q_");
  MS_CHECK_TRUE_RET(weight_attn_q_ != nullptr, false);
  weight_attn_o_ = std::make_shared<CondVar>(IsParamNode, "weight_attn_o");
  MS_CHECK_TRUE_RET(weight_attn_o_ != nullptr, false);
  weight_m_ = std::make_shared<CondVar>(IsParamNode, "weight_m");
  MS_CHECK_TRUE_RET(weight_m_ != nullptr, false);
  weight_p_ = std::make_shared<CondVar>(IsParamNode, "weight_p");
  MS_CHECK_TRUE_RET(weight_p_ != nullptr, false);
  bias_attn_qkv_ = std::make_shared<Var>("bias_attn_qkv");
  MS_CHECK_TRUE_RET(bias_attn_qkv_ != nullptr, false);
  bias_attn_o_ = std::make_shared<CondVar>(IsParamNode, "bias_attn_o");
  MS_CHECK_TRUE_RET(bias_attn_o_ != nullptr, false);
  bias_m_ = std::make_shared<CondVar>(IsParamNode, "bias_m");
  MS_CHECK_TRUE_RET(bias_m_ != nullptr, false);
  bias_p_ = std::make_shared<CondVar>(IsParamNode, "bias_p");
  MS_CHECK_TRUE_RET(bias_p_ != nullptr, false);
  mask_ = std::make_shared<Var>("mask");
  MS_CHECK_TRUE_RET(mask_ != nullptr, false);
  is_attention_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAttention), "is_attention");
  MS_CHECK_TRUE_RET(is_attention_ != nullptr, false);
  is_layernorm1_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm1");
  MS_CHECK_TRUE_RET(is_layernorm1_ != nullptr, false);
  is_layernorm2_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm2");
  MS_CHECK_TRUE_RET(is_layernorm2_ != nullptr, false);
  position_bias_ = std::make_shared<Var>("position_bias");
  MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
  is_act_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "activation");
  MS_CHECK_TRUE_RET(is_act_ != nullptr, false);
  eps1_ = std::make_shared<Var>("eps1_");
  MS_CHECK_TRUE_RET(eps1_ != nullptr, false);
  eps2_ = std::make_shared<Var>("eps2_");
  MS_CHECK_TRUE_RET(eps2_ != nullptr, false);
  eps3_ = std::make_shared<Var>("eps3_");
  MS_CHECK_TRUE_RET(eps3_ != nullptr, false);
  batch_valid_length_ = std::make_shared<Var>("batch_valid_length");
  MS_CHECK_TRUE_RET(batch_valid_length_ != nullptr, false);
  // init_reset_ is consumed by DefinePatternInitReset below; create it like the
  // other pattern variables (assuming it is not initialized elsewhere).
  init_reset_ = std::make_shared<Var>("init_reset");
  MS_CHECK_TRUE_RET(init_reset_ != nullptr, false);
  k_past_ = std::make_shared<Var>("k_past");
  MS_CHECK_TRUE_RET(k_past_ != nullptr, false);
  v_past_ = std::make_shared<Var>("v_past");
  MS_CHECK_TRUE_RET(v_past_ != nullptr, false);
  input_q_ = std::make_shared<Var>("input_q");
  MS_CHECK_TRUE_RET(input_q_ != nullptr, false);
  embedding_table_ = std::make_shared<Var>("embedding_table");
  MS_CHECK_TRUE_RET(embedding_table_ != nullptr, false);
  is_layernorm3_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm3");
  MS_CHECK_TRUE_RET(is_layernorm3_ != nullptr, false);
  position_ids_ = std::make_shared<Var>("position_ids_");
  MS_CHECK_TRUE_RET(position_ids_ != nullptr, false);
  embedding_table_input_ = std::make_shared<Var>("embedding_table_input_");
  MS_CHECK_TRUE_RET(embedding_table_input_ != nullptr, false);
  current_index_ = std::make_shared<Var>("current_index_");
  MS_CHECK_TRUE_RET(current_index_ != nullptr, false);
  embedding_table_pos_ = std::make_shared<Var>("embedding_table_pos_");
  MS_CHECK_TRUE_RET(embedding_table_pos_ != nullptr, false);
  return true;
}

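// Returns the sub-pattern feeding the attention inputs: the reshaped input
// itself for post-layernorm non-T5 graphs, otherwise the first layer norm
// (fused or decomposed) applied to it.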
VectorRef EncoderLayerFusion::getTuple(bool post_layernorm, bool layernorm_fusion = false,
                                       bool is_position_bias = false) const {
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1-reshape");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, input_, var1});
  if (post_layernorm && !is_position_bias) {
    return reshape1;
  }
  if (!layernorm_fusion) {
    return DefineLayerNorm(is_position_bias, reshape1, gamma1_, beta1_, eps1_);
  }
  auto layer_norm = VectorRef({is_layernorm1_, reshape1, gamma1_, beta1_});
  auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple = std::make_shared<Var>("var_tuple");
  auto tuple = VectorRef({is_tuple, layer_norm, var_tuple});
  return tuple;
}

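// Matches a layer norm exported as individual ops rather than LayerNormFusion:
// mean -> subtract -> square -> variance -> add(eps) -> sqrt -> divide ->
// scale(gamma) [-> shift(beta)]. T5-style position-bias graphs use the RMS
// variant without mean subtraction and beta.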
VectorRef EncoderLayerFusion::DefineLayerNorm(bool is_position_bias, BaseRef input, VarPtr gamma, VarPtr beta,
                                              VarPtr eps) const {
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce");
  MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
  auto reduce1 = VectorRef({is_reduce, input, var1});
  auto is_sub = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion), "sub-f");
  MS_CHECK_TRUE_RET(is_sub != nullptr, {});
  auto sub = VectorRef({is_sub, input, reduce1});
  auto is_sqr = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSquare), "sqr");
  MS_CHECK_TRUE_RET(is_sqr != nullptr, {});
  auto sqr = (is_position_bias) ? VectorRef({is_sqr, input}) : VectorRef({is_sqr, sub});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto is_reduce2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce2");
  MS_CHECK_TRUE_RET(is_reduce2 != nullptr, {});
  auto reduce2 = VectorRef({is_reduce2, sqr, var2});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is-add");
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto add = VectorRef({is_add, reduce2, eps});
  auto is_sqr2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSqrt), "sqr2");
  MS_CHECK_TRUE_RET(is_sqr2 != nullptr, {});
  auto sqr2 = VectorRef({is_sqr2, add});
  auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "real-div");
  MS_CHECK_TRUE_RET(is_div != nullptr, {});
  auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul");
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  if (is_position_bias) {
    auto real_div = VectorRef({is_div, input, sqr2});
    auto mul = VectorRef({is_mul, real_div, gamma});
    return mul;
  }
  auto real_div = VectorRef({is_div, sub, sqr2});
  auto mul = VectorRef({is_mul, real_div, gamma});
  auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is-add2");
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  auto add2 = VectorRef({is_add2, mul, beta});
  return add2;
}

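// Matches the incremental-inference reset of the key/value caches: k_past_ and
// v_past_ are multiplied by the cast init_reset flag and written back via
// Assign, with Depend edges ordering the assignments before `input` is used.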
VectorRef EncoderLayerFusion::DefinePatternInitReset(VectorRef input, bool value_reset, bool key_reset) const {
  auto is_cast = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCast), "cast_init");
  auto var_cast = std::make_shared<Var>("var_cast");
  auto cast = VectorRef({is_cast, init_reset_, var_cast});
  auto is_mul_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul_k");
  auto mul_k = VectorRef({is_mul_k, k_past_, cast});
  auto is_assign_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAssign), "assign_k");
  auto var_assign_k = std::make_shared<Var>("var_assign_k");
  auto assign_k = VectorRef({is_assign_k, k_past_, mul_k, var_assign_k});
  if (key_reset) return assign_k;
  auto is_depend_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend_k");
  auto depend_k = VectorRef({is_depend_k, input, assign_k});
  auto is_mul_v = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul_v");
  auto mul_v = VectorRef({is_mul_v, v_past_, cast});
  auto is_assign_v = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAssign), "assign_v");
  auto var_assign_v = std::make_shared<Var>("var_assign_v");
  auto assign_v = VectorRef({is_assign_v, v_past_, mul_v, var_assign_v});
  if (value_reset) return assign_v;
  auto is_depend_kv = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend_kv");
  auto depend_kv = VectorRef({is_depend_kv, depend_k, assign_v});
  return depend_kv;
}

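// Matches the batch-valid-length masking applied to a cache update: LessEqual
// against the reshaped batch_valid_length_, cast and expanded, then multiplied
// element-wise into `input`.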
BaseRef EncoderLayerFusion::DefineBatchValidLength(const BaseRef &input) const {
  auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape");
  MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
  auto var = std::make_shared<Var>("var");
  MS_CHECK_TRUE_RET(var != nullptr, {});
  auto reshape = VectorRef({is_reshape, batch_valid_length_, var});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto is_less = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLessEqual), "is_less_eq");
  MS_CHECK_TRUE_RET(is_less != nullptr, {});
  auto less = VectorRef({is_less, var2, reshape});
  auto var3 = std::make_shared<Var>("var3");
  MS_CHECK_TRUE_RET(var3 != nullptr, {});
  auto is_cast = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCast), "is_cast");
  MS_CHECK_TRUE_RET(is_cast != nullptr, {});
  auto cast = VectorRef({is_cast, less, var3});
  auto var4 = std::make_shared<Var>("var4");
  MS_CHECK_TRUE_RET(var4 != nullptr, {});
  auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims), "is_expand_dims");
  MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {});
  auto expand_dims = VectorRef({is_expand_dims, cast, var4});
  auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul");
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  auto mul = VectorRef({is_mul, input, expand_dims});
  return mul;
}

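// Capacity-limited top-k router of the MoE block, as read from the pattern
// below: expert ids are one-hot encoded, a CumSum assigns each token its
// position within its expert, positions beyond expert_capacity_ are masked out
// via Less, and the surviving assignments are combined into the dispatch
// tensor.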
VectorRef EncoderLayerFusion::DefinePatternMoETopKRouter(VectorRef input) const {
  auto var_onehot1 = std::make_shared<Var>("var_onehot1");
  auto var_onehot2 = std::make_shared<Var>("var_onehot2");
  auto var_onehot3 = std::make_shared<Var>("var_onehot3");
  auto var_onehot4 = std::make_shared<Var>("var_onehot4");
  auto is_onehot = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimOneHot), "is_onehot");
  MS_CHECK_TRUE_RET(is_onehot != nullptr, {});
  auto onehot = VectorRef({is_onehot, input, var_onehot1, var_onehot2, var_onehot3});
  auto is_cumsum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCumSum), "is_cumsum");
  MS_CHECK_TRUE_RET(is_cumsum != nullptr, {});
  auto var_cumsum = std::make_shared<Var>("var_cumsum");
  MS_CHECK_TRUE_RET(var_cumsum != nullptr, {});
  auto cumsum = VectorRef({is_cumsum, onehot, var_cumsum});
  auto is_mul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul1");
  MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
  auto mul_fusion1 = VectorRef({is_mul1, cumsum, onehot});
  auto is_less = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess), "is_less");
  MS_CHECK_TRUE_RET(is_less != nullptr, {});
  auto less = VectorRef({is_less, mul_fusion1, expert_capacity_});
  auto is_cast3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCast), "is_cast3");
  MS_CHECK_TRUE_RET(is_cast3 != nullptr, {});
  auto var_cast3 = std::make_shared<Var>("var_cast3");
  auto cast3 = VectorRef({is_cast3, less, var_cast3});
  auto is_mul4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul4");
  MS_CHECK_TRUE_RET(is_mul4 != nullptr, {});
  auto mul_fusion4 = VectorRef({is_mul4, cast3, onehot});
  auto is_reduce5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "is_reduce5");
  MS_CHECK_TRUE_RET(is_reduce5 != nullptr, {});
  auto var_reduce5 = std::make_shared<Var>("var_reduce5");
  auto reduce5 = VectorRef({is_reduce5, mul_fusion4, var_reduce5});
  auto is_expand_dims1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims), "is_expand_dims1");
  MS_CHECK_TRUE_RET(is_expand_dims1 != nullptr, {});
  auto var_expand_dims = std::make_shared<Var>("var_expand_dims1");
  auto expand_dims1 = VectorRef({is_expand_dims1, reduce5, var_expand_dims});
  auto is_mul7 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul7");
  MS_CHECK_TRUE_RET(is_mul7 != nullptr, {});
  auto mul_fusion7 = VectorRef({is_mul7, expand_dims1, onehot});
  auto is_onehot2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimOneHot), "is_onehot2");
  MS_CHECK_TRUE_RET(is_onehot2 != nullptr, {});
  auto is_cast2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCast), "is_cast2");
  MS_CHECK_TRUE_RET(is_cast2 != nullptr, {});
  auto var_cast2 = std::make_shared<Var>("var_cast2");
  auto cast2 = VectorRef({is_cast2, mul_fusion1, var_cast2});
  auto onehot2 = VectorRef({is_onehot2, cast2, var_onehot4, var_onehot2, var_onehot3});
  auto is_expand_dims2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims), "is_expand_dims2");
  MS_CHECK_TRUE_RET(is_expand_dims2 != nullptr, {});
  auto expand_dims2 = VectorRef({is_expand_dims2, mul_fusion7, var_expand_dims});
  auto is_mul8 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul8");
  MS_CHECK_TRUE_RET(is_mul8 != nullptr, {});
  auto mul_fusion8 = VectorRef({is_mul8, expand_dims2, onehot2});
  return mul_fusion8;
}

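// Matches the router front-end: expert ids pass through an activation and a
// StridedSlice starting at begin_expert_ids_, a Depend ties the slice to the
// layer-norm output, and the result is reshaped before top-k routing.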
VectorRef EncoderLayerFusion::DefinePatternMoERouter(VectorRef input_layernorm) const {
  auto is_relu = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "is_relu");
  auto relu = VectorRef({is_relu, expert_ids_});
  auto is_stride_slice = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice), "is_stride_slice");
  MS_CHECK_TRUE_RET(is_stride_slice != nullptr, {});
  auto var_stride_slice1 = std::make_shared<Var>("var_stride_slice1");
  MS_CHECK_TRUE_RET(var_stride_slice1 != nullptr, {});
  auto var_stride_slice2 = std::make_shared<Var>("var_stride_slice2");
  MS_CHECK_TRUE_RET(var_stride_slice2 != nullptr, {});
  auto var_stride_slice3 = std::make_shared<Var>("var_stride_slice3");
  MS_CHECK_TRUE_RET(var_stride_slice3 != nullptr, {});
  auto stride_slice = VectorRef({is_stride_slice, relu, begin_expert_ids_, var_stride_slice2, var_stride_slice3});
  auto is_depend = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend");
  auto depend = VectorRef({is_depend, stride_slice, input_layernorm});
  auto is_reshape_router = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "router-reshape");
  MS_CHECK_TRUE_RET(is_reshape_router != nullptr, {});
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto router_reshape = VectorRef({is_reshape_router, depend, var1});
  return DefinePatternMoETopKRouter(router_reshape);
}

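// Matches one expert FFN: MatMul -> bias add -> GeLU/FastGeLU -> MatMul,
// followed by an AllReduce and the second bias add (model-parallel export).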
VectorRef EncoderLayerFusion::DefinePatternMoEFfn(VectorRef input_reshape, bool gelu = false) const {
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  auto matmul1 = VectorRef({is_matmul1, input_reshape, weight_m_});
  auto is_add1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add1");
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  auto add = VectorRef({is_add1, matmul1, bias_m_});
  auto is_act = (gelu) ? std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "is_Gelu")
                       : std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimFastGeLU), "is_FastGelu");
  auto act = VectorRef({is_act, add});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  auto matmul2 = VectorRef({is_matmul2, act, weight_p_});
  auto is_all_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAllReduce), "is_all_reduce");
  auto all_reduce = VectorRef({is_all_reduce, matmul2});
  auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add2");
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  auto add2 = VectorRef({is_add2, all_reduce, bias_p_});
  return add2;
}

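// Full MoE block as read from the pattern below: the router output is reshaped
// into a dispatch matrix, tokens are gathered per expert (MatMul plus
// reshape/transpose), run through the expert FFN, and scattered back with a
// second MatMul against the same dispatch matrix. Multi-batch graphs carry an
// extra Transpose before the final reshape.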
VectorRef EncoderLayerFusion::DefinePatternMoE(VectorRef input_layernorm, bool multi_batch = false,
                                               bool gelu = false) const {
  auto router_output = DefinePatternMoERouter(input_layernorm);
  auto is_reshape10 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape10");
  MS_CHECK_TRUE_RET(is_reshape10 != nullptr, {});
  auto var_reshape10 = std::make_shared<Var>("var_reshape10");
  MS_CHECK_TRUE_RET(var_reshape10 != nullptr, {});
  auto reshape10 = VectorRef({is_reshape10, router_output, var_reshape10});
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
  auto matmul1 = VectorRef({is_matmul1, input_layernorm, reshape10});
  auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape3");
  MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
  auto var_reshape3 = std::make_shared<Var>("var_reshape3");
  MS_CHECK_TRUE_RET(var_reshape3 != nullptr, {});
  auto reshape3 = VectorRef({is_reshape3, matmul1, var_reshape3});
  auto is_transpose1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "is_transpose1");
  MS_CHECK_TRUE_RET(is_transpose1 != nullptr, {});
  auto var_transpose1 = std::make_shared<Var>("var_transpose1");
  MS_CHECK_TRUE_RET(var_transpose1 != nullptr, {});
  auto transpose1 = VectorRef({is_transpose1, reshape3, var_transpose1});
  auto is_reshape4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape4");
  MS_CHECK_TRUE_RET(is_reshape4 != nullptr, {});
  auto var_reshape4 = std::make_shared<Var>("var_reshape4");
  MS_CHECK_TRUE_RET(var_reshape4 != nullptr, {});
  auto reshape4 = VectorRef({is_reshape4, transpose1, var_reshape4});
  auto ffn_output = DefinePatternMoEFfn(reshape4, gelu);
  auto is_reshape6 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape6");
  MS_CHECK_TRUE_RET(is_reshape6 != nullptr, {});
  auto var_reshape6 = std::make_shared<Var>("var_reshape6");
  MS_CHECK_TRUE_RET(var_reshape6 != nullptr, {});
  auto reshape6 = VectorRef({is_reshape6, ffn_output, var_reshape6});
  VectorRef transpose2;
  if (multi_batch) {
    auto is_transpose2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "is_transpose2");
    MS_CHECK_TRUE_RET(is_transpose2 != nullptr, {});
    auto var_transpose2 = std::make_shared<Var>("var_transpose2");
    MS_CHECK_TRUE_RET(var_transpose2 != nullptr, {});
    transpose2 = VectorRef({is_transpose2, reshape6, var_transpose2});
  } else {
    transpose2 = reshape6;
  }
  auto is_reshape8 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape8");
  MS_CHECK_TRUE_RET(is_reshape8 != nullptr, {});
  auto var_reshape8 = std::make_shared<Var>("var_reshape8");
  MS_CHECK_TRUE_RET(var_reshape8 != nullptr, {});
  auto reshape8 = VectorRef({is_reshape8, transpose2, var_reshape8});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  auto matmul2 = VectorRef({is_matmul2, reshape10, reshape8});
  return matmul2;
}

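// Ties the fused output to the key/value cache updates: the attention's k and
// v outputs (TupleGetItem) are masked by batch_valid_length, assigned back to
// k_past_/v_past_, and chained onto the FFN output through Depend nodes.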
VectorRef EncoderLayerFusion::DefineDependKV(VectorRef input, VectorRef depend_v_input, bool moe = false) const {
  auto is_tuple_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple_k = std::make_shared<Var>("var_tuple_k");
  auto tuple_k = VectorRef({is_tuple_k, input, var_tuple_k});
  auto is_assign_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAssign), "assign_k_fuse");
  auto var_assign_k = std::make_shared<Var>("var_assign_k");
  auto assign_k = VectorRef({is_assign_k, k_past_, DefineBatchValidLength(tuple_k), var_assign_k});
  auto is_tuple_v = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple_v = std::make_shared<Var>("var_tuple_v");
  auto tuple_v = VectorRef({is_tuple_v, input, var_tuple_v});
  auto is_assign_v = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAssign), "assign_v_fuse");
  auto var_assign_v = std::make_shared<Var>("var_assign_v");
  auto assign_v = VectorRef({is_assign_v, v_past_, DefineBatchValidLength(tuple_v), var_assign_v});
  auto is_depend_v_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend_v");
  auto depend_v_mul = VectorRef({is_depend_v_mul, depend_v_input, assign_k});
  auto is_depend_kv_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend_kv");
  return VectorRef({is_depend_kv_mul, depend_v_mul, assign_v});
}

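// Matches the dense (non-MoE) FFN of the sigma patterns; in the distributed
// variant the second bias add happens after AllReduce instead of inside the
// MatMul.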
VectorRef EncoderLayerFusion::DefinePatternSigmaFfn(BaseRef input, bool gelu = false, bool distributed = false) const {
  auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2");
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto reshape2 = VectorRef({is_reshape2, input, var2});
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  auto matmul1 = VectorRef({is_matmul1, reshape2, weight_m_, bias_m_});
  auto is_act = (gelu) ? std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "is_Gelu")
                       : std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimFastGeLU), "is_FastGelu");
  MS_CHECK_TRUE_RET(is_act != nullptr, {});
  auto act = VectorRef({is_act, matmul1});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  auto matmul2 = VectorRef({is_matmul2, act, weight_p_});
  if (!distributed) matmul2.push_back(bias_p_);
  auto is_all_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAllReduce), "is_all_reduce");
  auto all_reduce = VectorRef({is_all_reduce, matmul2});
  auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add2");
  auto add2 = VectorRef({is_add2, all_reduce, bias_p_});
  auto is_reshape5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder5");
  MS_CHECK_TRUE_RET(is_reshape5 != nullptr, {});
  auto var5 = std::make_shared<Var>("var5");
  MS_CHECK_TRUE_RET(var5 != nullptr, {});
  auto reshape5 = (distributed) ? VectorRef({is_reshape5, add2, var5}) : VectorRef({is_reshape5, matmul2, var5});
  return reshape5;
}

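// Full PanGu-sigma encoder layer: pre-layernorm, fused Attention primitive,
// residual add, second layernorm, dense or MoE FFN, and the k/v cache update
// chain. Flags select the embedding front-end, the query-layer head (Gather
// plus MatMul against the embedding table), a trailing layer norm, and the
// gelu flavor.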
VectorRef EncoderLayerFusion::DefinePatternEncoderSigma(bool moe = false, bool use_past = true,
                                                        bool distributed = false, bool is_layer_norm = false,
                                                        bool query_layer = false, bool multi_batch = false,
                                                        bool first_encoder = false, bool gelu = false) const {
  VectorRef input, q_input;
  if (first_encoder) {
    input = DefineFirstEncoder(false);
  }
  auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
  MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
  auto var = std::make_shared<Var>("var");
  MS_CHECK_TRUE_RET(var != nullptr, {});
  if (query_layer) {
    auto var_gather_emb = std::make_shared<Var>("var_gather_emb");
    MS_CHECK_TRUE_RET(var_gather_emb != nullptr, {});
    auto is_gather_emb = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimGather), "is_gather_emb");
    MS_CHECK_TRUE_RET(is_gather_emb != nullptr, {});
    auto gather_emb = VectorRef({is_gather_emb, embedding_table_pos_, input_q_, var_gather_emb});
    auto is_reshape_emb = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder7");
    MS_CHECK_TRUE_RET(is_reshape_emb != nullptr, {});
    auto var_reshape_emb = std::make_shared<Var>("var_reshape_emb");
    MS_CHECK_TRUE_RET(var_reshape_emb != nullptr, {});
    q_input = VectorRef({is_reshape_emb, gather_emb, var_reshape_emb});
  }
  auto reshape = (first_encoder) ? VectorRef({is_reshape, DefineLayerNorm(false, input, gamma1_, beta1_, eps1_), var})
                                 : VectorRef({is_reshape, DefineLayerNorm(false, input_, gamma1_, beta1_, eps1_), var});
  auto attention = (query_layer) ? VectorRef({is_attention_, q_input, reshape, reshape, weight_attn_q_,
                                              weight_attn_qkv_, weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_})
                                 : VectorRef({is_attention_, reshape, reshape, reshape, weight_attn_qkv_,
                                              weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_});
  auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple = std::make_shared<Var>("var_tuple");
  auto tuple = VectorRef({is_tuple, attention, var_tuple});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
  auto add = (first_encoder) ? VectorRef({is_add, input, tuple}) : VectorRef({is_add, input_, tuple});
  auto layer_norm2 = DefineLayerNorm(false, add, gamma2_, beta2_, eps2_);
  auto ffn_output =
    (moe) ? DefinePatternMoE(layer_norm2, multi_batch, gelu) : DefinePatternSigmaFfn(layer_norm2, gelu, distributed);
  auto depend_kv_mul = DefineDependKV(attention, ffn_output, moe);
  auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
  auto add3 = VectorRef({is_add3, add, depend_kv_mul});
  if (is_layer_norm) return DefineLayerNorm(false, add3, gamma3_, beta3_, eps3_);
  if (query_layer) {
    auto is_reshape7 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder7");
    MS_CHECK_TRUE_RET(is_reshape7 != nullptr, {});
    auto var7 = std::make_shared<Var>("var7");
    MS_CHECK_TRUE_RET(var7 != nullptr, {});
    auto reshape7 = VectorRef({is_reshape7, add3, var7});
    add3 = reshape7;
    auto is_gather = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimGather), "is_gather");
    MS_CHECK_TRUE_RET(is_gather != nullptr, {});
    auto var_gather2 = std::make_shared<Var>("var_gather2");
    auto gather = VectorRef({is_gather, add3, current_index_, var_gather2});
    auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
    MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
    auto matmul3 = VectorRef({is_matmul3, gather, embedding_table_});
    return matmul3;
  }
  return add3;
}

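// Full PanGu-alpha encoder layer: fused LayerNormFusion ops (tapped through
// TupleGetItem), Attention, FFN with FastGeLU, and Depend scaffolding around
// the k_past_/v_past_ caches when use_past is set.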
VectorRef EncoderLayerFusion::DefinePatternEncoderAlpha(bool moe = false, bool distributed = false,
                                                        bool is_layer_norm = false, bool query_layer = false,
                                                        bool use_past = false) const {
  VectorRef add2;
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1-reshape");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, input_, var1});
  VectorRef layer_norm = (query_layer) ? VectorRef({is_layernorm1_, input_, gamma1_, beta1_})
                                       : VectorRef({is_layernorm1_, reshape1, gamma1_, beta1_});
  auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple = std::make_shared<Var>("var_tuple");
  auto tuple = VectorRef({is_tuple, layer_norm, var_tuple});
  auto is_depend = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend");
  auto depend = VectorRef({is_depend, tuple, k_past_});
  auto is_depend2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend2");
  auto depend2 = VectorRef({is_depend2, depend, v_past_});
  auto attention = (query_layer) ? VectorRef({is_attention_, input_q_, DefinePatternInitReset(tuple, false, false),
                                              DefinePatternInitReset(tuple, false, false), weight_attn_q_,
                                              weight_attn_qkv_, weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_})
                                 : VectorRef({is_attention_, depend2, depend2, depend2, weight_attn_qkv_,
                                              weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_});
  auto is_tuple1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item1");
  auto var_tuple1 = std::make_shared<Var>("var_tuple1");
  auto tuple1 = VectorRef({is_tuple1, attention, var_tuple1});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
  auto add = (query_layer) ? VectorRef({is_add, input_, tuple1}) : VectorRef({is_add, reshape1, tuple1});
  auto layer_norm2 = VectorRef({is_layernorm2_, add, gamma2_, beta2_});
  auto is_tuple2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item2");
  auto var_tuple2 = std::make_shared<Var>("var_tuple2");
  auto tuple2 = VectorRef({is_tuple2, layer_norm2, var_tuple2});
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  auto matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
  auto is_act = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "is_FastGelu");
  MS_CHECK_TRUE_RET(is_act != nullptr, {});
  auto act = VectorRef({is_act, matmul1});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  auto matmul2 =
    (distributed) ? VectorRef({is_matmul2, act, weight_p_}) : VectorRef({is_matmul2, act, weight_p_, bias_p_});
  if (distributed) {
    auto is_all_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAllReduce), "is_all_reduce");
    auto all_reduce = VectorRef({is_all_reduce, matmul2});
    auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add2");
    add2 = VectorRef({is_add2, all_reduce, bias_p_});
  }
  if (use_past && !query_layer) {
    auto is_depend3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend3");
    auto depend3 = VectorRef({is_depend3, v_past_, v_past_});
    auto is_depend4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend4");
    auto depend4 =
      (distributed) ? VectorRef({is_depend4, add2, depend3}) : VectorRef({is_depend4, matmul2, depend3});
    auto is_depend5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend5");
    auto depend5 = VectorRef({is_depend5, k_past_, k_past_});
    auto is_depend6 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend6");
    matmul2 = VectorRef({is_depend6, depend4, depend5});
  }
  auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2");
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto reshape2 = VectorRef({is_reshape2, add, var2});
  auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder3");
  MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
  auto var3 = std::make_shared<Var>("var3");
  MS_CHECK_TRUE_RET(var3 != nullptr, {});
  auto reshape3 = VectorRef({is_reshape3, matmul2, var3});
  auto depend_kv_mul =
    (distributed) ? DefineDependKV(attention, add2, true) : DefineDependKV(attention, matmul2, false);
  auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
  auto add3 = (query_layer) ? VectorRef({is_add3, add, depend_kv_mul}) : VectorRef({is_add3, reshape2, reshape3});
  if (is_layer_norm) {
    auto is_reshape_norm = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-norm");
    MS_CHECK_TRUE_RET(is_reshape_norm != nullptr, {});
    auto var_norm = std::make_shared<Var>("var_norm");
    MS_CHECK_TRUE_RET(var_norm != nullptr, {});
    auto reshape_norm = VectorRef({is_reshape_norm, add3, var_norm});
    auto layer_norm3 = VectorRef({is_layernorm3_, reshape_norm, gamma3_, beta3_});
    return layer_norm3;
  }
  if (query_layer) {
    auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
    MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
    auto matmul3 = VectorRef({is_matmul3, add3, embedding_table_});
    return matmul3;
  }
  return add3;
}

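// Matches the embedding front-end fused into the first encoder layer: word
// embedding Gather plus positional embedding Gather, summed.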
VectorRef EncoderLayerFusion::DefineFirstEncoder(bool distributed = false) const {
  auto is_depend = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend");
  auto depend = VectorRef({is_depend, input_, expert_ids_});
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto is_gather = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimGather), "is_gather");
  MS_CHECK_TRUE_RET(is_gather != nullptr, {});
  auto gather = (distributed) ? VectorRef({is_gather, embedding_table_input_, depend, var1})
                              : VectorRef({is_gather, embedding_table_input_, input_, var1});
  auto is_gather2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimGather), "is_gather2");
  MS_CHECK_TRUE_RET(is_gather2 != nullptr, {});
  auto gather2 = VectorRef({is_gather2, embedding_table_pos_, position_ids_, var1});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto add = VectorRef({is_add, gather, gather2});
  return add;
}

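// Generic (BERT/T5-style) encoder layer used by the non-PanGu patterns:
// attention with optional mask and position bias, first/second layer norm
// either fused or decomposed, and a two-MatMul FFN with activation.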
VectorRef EncoderLayerFusion::DefinePatternEncoderLayer(bool post_layernorm = true, bool layernorm_fusion = false,
                                                        bool is_position_bias = false, bool mask = true,
                                                        bool is_layer_norm = false) const {
  VectorRef tuple, tuple2, tuple3, reshape2, matmul1, inputs, layer_norm2;
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, input_, var1});
  if (!is_position_bias) {
    inputs = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_, weight_attn_o_,
                        bias_attn_qkv_, bias_attn_o_});
  } else {
    inputs = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_, weight_attn_o_,
                        position_bias_});
  }
  if (mask) inputs.push_back(mask_);
  auto attention = VectorRef(inputs);
  if (!is_position_bias) {
    auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
    auto var_tuple = std::make_shared<Var>("var_tuple");
    tuple = VectorRef({is_tuple, attention, var_tuple});
  } else {
    tuple = attention;
  }
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
  auto add = (is_position_bias && post_layernorm)
               ? VectorRef({is_add, getTuple(post_layernorm, layernorm_fusion, is_position_bias), tuple})
               : VectorRef({is_add, reshape1, tuple});
  auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2");
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  if (layernorm_fusion) {
    layer_norm2 = VectorRef({is_layernorm2_, add, gamma2_, beta2_});
    auto is_tuple2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item2");
    auto var_tuple2 = std::make_shared<Var>("var_tuple2");
    tuple2 = VectorRef({is_tuple2, layer_norm2, var_tuple2});
  } else {
    tuple2 = DefineLayerNorm(is_position_bias, add, gamma2_, beta2_, eps2_);
  }
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
  if (is_position_bias) {
    reshape2 = (post_layernorm) ? VectorRef({is_reshape2, tuple2, var2}) : VectorRef({is_reshape2, add, var2});
    matmul1 = VectorRef({is_matmul1, tuple2, weight_m_});
  } else if (post_layernorm || !layernorm_fusion) {
    reshape2 = VectorRef({is_reshape2, tuple2, var2});
    matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
  } else {
    reshape2 = VectorRef({is_reshape2, add, var2});
    matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
  }
  auto act = VectorRef({is_act_, matmul1});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  auto matmul2 =
    (is_position_bias) ? VectorRef({is_matmul2, matmul1, weight_p_}) : VectorRef({is_matmul2, act, weight_p_, bias_p_});
  auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder3");
  MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
  auto var3 = std::make_shared<Var>("var3");
  MS_CHECK_TRUE_RET(var3 != nullptr, {});
  auto reshape3 = VectorRef({is_reshape3, matmul2, var3});
  auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
  auto add3 = VectorRef({is_add3, reshape2, reshape3});
  if (is_layer_norm) return DefineLayerNorm(is_position_bias, add3, gamma3_, beta3_, eps3_);
  if (!post_layernorm || !layernorm_fusion) return add3;
  auto is_reshape4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder4");
  MS_CHECK_TRUE_RET(is_reshape4 != nullptr, {});
  auto var4 = std::make_shared<Var>("var4");
  MS_CHECK_TRUE_RET(var4 != nullptr, {});
  auto reshape4 = VectorRef({is_reshape4, add3, var4});
  if (layernorm_fusion) {
    auto layer_norm = VectorRef({is_layernorm1_, reshape4, gamma1_, beta1_});
    auto is_tuple3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item3");
    auto var_tuple3 = std::make_shared<Var>("var_tuple3");
    tuple3 = VectorRef({is_tuple3, layer_norm, var_tuple3});
  } else {
    tuple3 = DefineLayerNorm(is_position_bias, reshape4, gamma1_, beta1_, eps1_);
  }
  auto is_reshape5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder5");
  MS_CHECK_TRUE_RET(is_reshape5 != nullptr, {});
  auto var5 = std::make_shared<Var>("var5");
  MS_CHECK_TRUE_RET(var5 != nullptr, {});
  auto reshape5 = VectorRef({is_reshape5, tuple3, var5});
  return reshape5;
}

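// Registers every pattern variant this pass can match. The flag order for
// DefinePatternEncoderSigma is (moe, use_past, distributed, is_layer_norm,
// query_layer, multi_batch, first_encoder, gelu).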
std::unordered_map<std::string, VectorRef> EncoderLayerFusion::DefinePatterns() const {
  std::unordered_map<std::string, VectorRef> patterns;
  if (!Init()) {
    MS_LOG(ERROR) << "initialize members failed.";
    return patterns;
  }
  if (embedding_layer_) {
    patterns[kPatternSigmaEmbedding] = DefinePatternEncoderSigma(false, true, false, false, false, false, true);
    patterns[kPatternSigmaEmbeddingDistributed] =
      DefinePatternEncoderSigma(false, true, true, false, false, false, true);
    patterns[kPatternSigmaDistributedEmbedding] =
      DefinePatternEncoderSigma(false, true, true, false, false, false, true);
    patterns[kPatternSigmaDistributedEmbeddingMB] =
      DefinePatternEncoderSigma(false, true, true, false, false, true, true);
  } else {
    // pangu sigma
    patterns[kPatternSigmaDistributedEmbeddingMBGELU] =
      DefinePatternEncoderSigma(false, true, true, false, false, true, true, true);
    patterns[kPatternSigmaDistributed] = DefinePatternEncoderSigma(false, true, true, false, false);
    patterns[kPatternSigmaMoeDistributed] = DefinePatternEncoderSigma(true, true, true, false, false);
    patterns[kPatternSigmaMoeWithLastLayerNormDistributed] = DefinePatternEncoderSigma(true, true, true, true, false);
    patterns[kPatternSigmaQueryLayerDistributed] = DefinePatternEncoderSigma(false, true, true, false, true);
    patterns[kPatternSigmaQueryLayerDistributedMoe] = DefinePatternEncoderSigma(true, true, true, false, true);
    patterns[kPatternSigma] = DefinePatternEncoderSigma(false, true, false, false, false);
    patterns[kPatternSigmaWithLastLayerNorm] = DefinePatternEncoderSigma(false, true, false, true, false);
    patterns[kPatternSigmaQuery] = DefinePatternEncoderSigma(false, true, false, false, true);
    patterns[kPatternSigmaMoe] = DefinePatternEncoderSigma(true, true, false, false, false);
    patterns[kPatternSigmaWithLastLayerNormDistributed] = DefinePatternEncoderSigma(false, true, true, true, false);
    patterns[kPatternSigmaMoeWithLastLayerNorm] = DefinePatternEncoderSigma(true, true, true, true, false);
    patterns[kPatternSigmaQueryLayerMoe] = DefinePatternEncoderSigma(true, true, false, false, true);
    // multi batch, fast-gelu
    patterns[kPatternSigmaMoeWithLastLayerNormDistributedMB] =
      DefinePatternEncoderSigma(true, true, true, true, false, true);
    patterns[kPatternSigmaWithLastLayerNormDistributedMB] =
      DefinePatternEncoderSigma(false, true, true, true, false, true);
    patterns[kPatternSigmaDistributedMB] = DefinePatternEncoderSigma(false, true, true, false, false, true);
    patterns[kPatternSigmaQueryLayerDistributedMB] = DefinePatternEncoderSigma(false, true, true, false, true, true);
    patterns[kPatternSigmaQueryLayerDistributedMBMoe] = DefinePatternEncoderSigma(true, true, true, false, true, true);
    patterns[kPatternSigmaMoeDistributedMB] = DefinePatternEncoderSigma(true, true, true, false, false, true);
    // gelu
    patterns[kPatternSigmaMoeWithLastLayerNormDistributedMBGELU] =
      DefinePatternEncoderSigma(true, true, true, true, false, true, false, true);
    patterns[kPatternSigmaWithLastLayerNormDistributedMBGELU] =
      DefinePatternEncoderSigma(false, true, true, true, false, true, false, true);
    patterns[kPatternSigmaDistributedMBGELU] =
      DefinePatternEncoderSigma(false, true, true, false, false, true, false, true);
    patterns[kPatternSigmaQueryLayerDistributedMBGELU] =
      DefinePatternEncoderSigma(true, true, true, false, true, true, false, true);
    patterns[kPatternSigmaMoeDistributedMBGELU] =
      DefinePatternEncoderSigma(true, true, true, false, false, true, false, true);
    // pangu alpha
    patterns[kPatternDistributedAlpha] = DefinePatternEncoderAlpha(false, true, false, false, true);
    patterns[kPatternDistributedAlphaWithLastLayerNorm] = DefinePatternEncoderAlpha(false, true, true, false, true);
    patterns[kPatternQueryLayerUsePastDistributed] = DefinePatternEncoderAlpha(false, true, false, true, true);
    patterns[kPatternQueryLayerUsePast] = DefinePatternEncoderAlpha(false, false, false, true, true);
    patterns[kPatternEncoderLayerPreNormUsePast] = DefinePatternEncoderAlpha(false, false, false, false, true);
    patterns[kPatternEncoderLayerUsePastWithLastNorm] = DefinePatternEncoderAlpha(false, false, true, false, true);
    patterns[kPatternEncoderLayerNormT5Pre] = DefinePatternEncoderLayer(false, false, true, true, true);
    patterns[kPatternEncoderLayerPre] = DefinePatternEncoderLayer(false);
    patterns[kPatternEncoderLayerPost] = DefinePatternEncoderLayer(true);
    patterns[kPatternEncoderLayerPostNorm] = DefinePatternEncoderLayer(true, true);
    patterns[kPatternEncoderLayerPreNorm] = DefinePatternEncoderLayer(false, true);
    patterns[kPatternEncoderLayerT5Pre] = DefinePatternEncoderLayer(false, false, true, true);
    patterns[kPatternEncoderLayerT5Post] = DefinePatternEncoderLayer(true, false, true, true);
  }
  return patterns;
}

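// The predicates below map a matched pattern name back to the boolean
// attributes recorded on the fused node in Process().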
bool EncoderLayerFusion::IsUsePastAlpha(const std::string &pattern_name) const {
  if (pattern_name == kPatternDistributedAlpha || pattern_name == kPatternDistributedAlphaWithLastLayerNorm ||
      pattern_name == kPatternQueryLayerUsePastDistributed || pattern_name == kPatternQueryLayerUsePast ||
      pattern_name == kPatternEncoderLayerPreNormUsePast || pattern_name == kPatternEncoderLayerUsePastWithLastNorm)
    return true;
  return false;
}

bool EncoderLayerFusion::IsUsePastMB(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigmaWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributedMBGELU ||
      pattern_name == kPatternSigmaQueryLayerDistributedMB || pattern_name == kPatternSigmaQueryLayerDistributedMBMoe ||
      pattern_name == kPatternSigmaQueryLayerDistributedMBGELU ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMBGELU ||
      pattern_name == kPatternSigmaDistributedEmbeddingMB || pattern_name == kPatternSigmaDistributedEmbeddingMBGELU ||
      pattern_name == kPatternSigmaDistributedMB || pattern_name == kPatternSigmaDistributedMBGELU ||
      pattern_name == kPatternSigmaMoeDistributedMB || pattern_name == kPatternSigmaMoeDistributedMBGELU)
    return true;
  return false;
}

bool EncoderLayerFusion::IsUsePast(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigmaDistributed || pattern_name == kPatternSigmaDistributedEmbedding ||
      pattern_name == kPatternSigmaMoeDistributed || pattern_name == kPatternSigmaWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaQueryLayerDistributed || pattern_name == kPatternSigmaQueryLayerDistributedMoe ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributed || pattern_name == kPatternSigma ||
      pattern_name == kPatternSigmaQuery || pattern_name == kPatternSigmaWithLastLayerNorm ||
      pattern_name == kPatternSigmaMoe || pattern_name == kPatternSigmaMoeWithLastLayerNorm ||
      pattern_name == kPatternSigmaQueryLayerMoe || pattern_name == kPatternSigmaEmbedding ||
      pattern_name == kPatternSigmaEmbeddingDistributed)
    return true;
  return false;
}

bool EncoderLayerFusion::IsLastLayerNorm(const std::string &pattern_name) const {
  if (pattern_name == kPatternEncoderLayerNormT5Pre || pattern_name == kPatternEncoderLayerUsePastWithLastNorm ||
      pattern_name == kPatternDistributedAlphaWithLastLayerNorm ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaMoeWithLastLayerNorm || pattern_name == kPatternSigmaWithLastLayerNorm ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributedMBGELU ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMBGELU)
    return true;
  return false;
}

bool EncoderLayerFusion::IsLayerNormFusion(const std::string &pattern_name) const {
  if (pattern_name == kPatternEncoderLayerPostNorm || pattern_name == kPatternEncoderLayerPreNorm ||
      pattern_name == kPatternEncoderLayerPreNormUsePast || pattern_name == kPatternQueryLayerUsePast ||
      pattern_name == kPatternEncoderLayerUsePastWithLastNorm || pattern_name == kPatternDistributedAlpha ||
      pattern_name == kPatternQueryLayerUsePastDistributed || pattern_name == kPatternDistributedAlphaWithLastLayerNorm)
    return true;
  return false;
}

bool EncoderLayerFusion::IsMoe(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigmaMoeDistributed || pattern_name == kPatternSigmaMoeWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaMoe || pattern_name == kPatternSigmaMoeWithLastLayerNorm ||
      pattern_name == kPatternSigmaQueryLayerMoe || pattern_name == kPatternSigmaQueryLayerDistributedMBGELU ||
      pattern_name == kPatternSigmaQueryLayerDistributedMoe ||
      pattern_name == kPatternSigmaQueryLayerDistributedMBMoe ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMBGELU ||
      pattern_name == kPatternSigmaMoeDistributedMB || pattern_name == kPatternSigmaMoeDistributedMBGELU)
    return true;
  return false;
}

bool EncoderLayerFusion::IsFastGeluDistributed(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigmaDistributed || pattern_name == kPatternSigmaDistributedEmbedding ||
      pattern_name == kPatternSigmaMoeDistributed || pattern_name == kPatternSigmaWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaQueryLayerDistributed || pattern_name == kPatternSigmaQueryLayerDistributedMoe ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaEmbeddingDistributed ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaQueryLayerDistributedMB || pattern_name == kPatternSigmaQueryLayerDistributedMBMoe ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaDistributedEmbeddingMB || pattern_name == kPatternSigmaDistributedMB ||
      pattern_name == kPatternSigmaMoeDistributedMB)
    return true;
  return false;
}

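// Inferred: single-device Sigma variants whose FFN activation is FastGelu.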
bool EncoderLayerFusion::IsFastGelu(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigma || pattern_name == kPatternSigmaQuery ||
      pattern_name == kPatternSigmaWithLastLayerNorm || pattern_name == kPatternSigmaEmbedding ||
      pattern_name == kPatternSigmaMoe || pattern_name == kPatternSigmaMoeWithLastLayerNorm ||
      pattern_name == kPatternSigmaQueryLayerMoe)
    return true;
  return false;
}

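// Inferred: query-layer variants, which consume an extra query input, a query weight, and embedding
// tables when the fused node is assembled (see CreateMaskedEncoderLayerFusionNode).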
bool EncoderLayerFusion::IsQueryLayer(const std::string &pattern_name) const {
  if (pattern_name == kPatternQueryLayerUsePast || pattern_name == kPatternQueryLayerUsePastDistributed ||
      pattern_name == kPatternSigmaQueryLayerDistributed || pattern_name == kPatternSigmaQueryLayerMoe ||
      pattern_name == kPatternSigmaQueryLayerDistributedMoe || pattern_name == kPatternSigmaQueryLayerDistributedMB ||
      pattern_name == kPatternSigmaQueryLayerDistributedMBMoe ||
      pattern_name == kPatternSigmaQueryLayerDistributedMBGELU || pattern_name == kPatternSigmaQuery)
    return true;
  return false;
}
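
// Pattern-dispatch entry point invoked by the pattern engine for every matched subgraph: derives the
// per-pattern flags (last layernorm, layernorm fusion, use-past, MoE, FastGelu, embedding, position
// bias, post-layernorm, query layer) that steer how the fused node is created.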
AnfNodePtr EncoderLayerFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
                                       const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    return nullptr;
  }
  bool mask = true;
  is_layernorm_ = IsLastLayerNorm(pattern_name);
  is_layernorm_fusion_ = IsLayerNormFusion(pattern_name);
  is_use_past_ = IsUsePast(pattern_name) || IsUsePastMB(pattern_name) || IsUsePastAlpha(pattern_name);
  is_moe_ = IsMoe(pattern_name);
  is_fast_gelu_ = IsFastGelu(pattern_name) || IsFastGeluDistributed(pattern_name);
  if (pattern_name == kPatternSigmaDistributedEmbeddingMB || pattern_name == kPatternSigmaDistributedEmbedding ||
      pattern_name == kPatternSigmaDistributedEmbeddingMBGELU || pattern_name == kPatternSigmaEmbedding ||
      pattern_name == kPatternSigmaEmbeddingDistributed)
    is_embedding_layer_ = true;
  if (pattern_name == kPatternEncoderLayerT5Pre || pattern_name == kPatternEncoderLayerT5Post ||
      pattern_name == kPatternEncoderLayerNormT5Pre)
    is_position_bias_ = true;
  if (pattern_name == kPatternEncoderLayerPost || pattern_name == kPatternEncoderLayerPostNorm ||
      pattern_name == kPatternEncoderLayerT5Post)
    is_post_layernorm_ = true;
  is_query_layer_ = IsQueryLayer(pattern_name);
  return CreateMaskedEncoderLayerFusionNode(func_graph, equiv, node, is_post_layernorm_, mask);
}

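// Verifies that the activation node matched by `input_prim` really is GELU; patterns without position
// bias, use-past, or query layer only fuse when this check passes (see CheckPattern).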
bool EncoderLayerFusion::IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                   const VarPtr &input_prim) const {
  auto act_input = GetAttribute(func_graph, equiv, input_prim);
  MS_CHECK_TRUE_RET(act_input != nullptr, false);
  auto act_primitive = ops::GetOperator<ops::Activation>(act_input);
  MS_CHECK_TRUE_RET(act_primitive != nullptr, false);
  auto act_primitive_c = act_primitive->GetPrim();
  if (act_primitive_c->GetAttr(ops::kActivationType) == nullptr ||
      act_primitive->get_activation_type() != mindspore::GELU) {
    return false;
  }
  return true;
}

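// Extracts a LayerNorm epsilon that was matched as a scalar float32 constant (a ValueNode holding a
// tensor) rather than as a primitive attribute.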
STATUS EncoderLayerFusion::GetEps(const EquivPtr &equiv, VarPtr node_name, float *eps) const {
  if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
    MS_LOG(ERROR) << node_name << " is not an AnfNodePtr";
    return RET_ERROR;
  }
  AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
  MS_ASSERT(node != nullptr);
  if (utils::isa<ValueNodePtr>(node)) {
    auto value_ptr_node = utils::cast<ValueNodePtr>(node);
    auto value_node = utils::cast<ValuePtr>(value_ptr_node->value());
    if (value_node->isa<tensor::Tensor>()) {
      auto tensor = value_node->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      // The epsilon is stored as a scalar float32 tensor; read its first element.
      *eps = *reinterpret_cast<float *>(tensor->data().data());
      return RET_OK;
    }
  }
  return RET_ERROR;
}

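// Resolves the primitive (input 0) of the CNode bound to `node_name` in the match equivalence table;
// if the bound node is not itself a CNode, its first user in the graph is taken instead.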
AnfNodePtr EncoderLayerFusion::GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                            VarPtr node_name) const {
  if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
    MS_LOG(ERROR) << node_name << " is not an AnfNodePtr";
    return nullptr;
  }
  AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
  MS_ASSERT(node != nullptr);
  if (node == nullptr || !utils::isa<CNodePtr>(node)) {
    // The matched node is not a CNode itself; fall back to its first user in the graph.
    auto manager = func_graph->manager();
    if (manager == nullptr) {
      return nullptr;
    }
    auto users = manager->node_users();
    auto it = users.find(node);
    if (it != users.end()) {
      node = it->second.front().first;
    }
    if (node == nullptr || !utils::isa<CNodePtr>(node)) {
      return nullptr;
    }
  }
  auto cnode = utils::cast<CNodePtr>(node);
  MS_ASSERT(cnode != nullptr);
  // input(0) of a CNode is its primitive value node.
  return cnode->input(0);
}

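// Validates the matched subgraph and collects the fused op's attributes: head count/size and scale
// from the Attention primitive, epsilons from LayerNormFusion primitives or matched constants, and
// the FFN activation type.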
STATUS EncoderLayerFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num,
                                        int *head_size, float *eps1, float *eps2, float *eps3, float *scale) const {
  auto attn_input = GetAttribute(func_graph, equiv, is_attention_);
  MS_CHECK_TRUE_RET(attn_input != nullptr, RET_ERROR);
  auto attn_prim = ops::GetOperator<ops::Attention>(attn_input);
  MS_CHECK_TRUE_RET(attn_prim != nullptr, RET_ERROR);
  if (attn_prim->GetAttr(ops::kNumHeads) != nullptr) *head_num = attn_prim->get_head_num();
  if (attn_prim->GetAttr(ops::kSizePerHead) != nullptr) *head_size = attn_prim->get_head_size();
  if (attn_prim->GetAttr(ops::kPositionBias1) != nullptr) is_position_bias_ = attn_prim->get_position_bias();
  if (attn_prim->GetAttr(ops::kScale) != nullptr) *scale = attn_prim->get_scale();
  if (is_layernorm_fusion_) {
    auto layrn1_input = GetAttribute(func_graph, equiv, is_layernorm1_);
    auto layrn1_prim = ops::GetOperator<ops::LayerNormFusion>(layrn1_input);
    MS_CHECK_TRUE_RET(layrn1_prim != nullptr, RET_ERROR);
    *eps1 = layrn1_prim->get_epsilon();
    auto layrn2_input = GetAttribute(func_graph, equiv, is_layernorm2_);
    auto layrn2_prim = ops::GetOperator<ops::LayerNormFusion>(layrn2_input);
    MS_CHECK_TRUE_RET(layrn2_prim != nullptr, RET_ERROR);
    *eps2 = layrn2_prim->get_epsilon();
    if (is_layernorm_) {
      auto layrn3_input = GetAttribute(func_graph, equiv, is_layernorm3_);
      auto layrn3_prim = ops::GetOperator<ops::LayerNormFusion>(layrn3_input);
      MS_CHECK_TRUE_RET(layrn3_prim != nullptr, RET_ERROR);
      *eps3 = layrn3_prim->get_epsilon();
    }
  } else {
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps1_, eps1) == RET_OK, RET_ERROR, "not found eps1");
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps2_, eps2) == RET_OK, RET_ERROR, "not found eps2");
    if (is_layernorm_) {
      MS_CHECK_TRUE_MSG(GetEps(equiv, eps3_, eps3) == RET_OK, RET_ERROR, "not found eps3");
    }
  }
  act_type_ = (is_position_bias_) ? (ActType::ActType_Relu)
                                  : (is_fast_gelu_) ? (ActType::ActType_FastGelu) : (ActType::ActType_Gelu);
  if (!is_position_bias_ && !is_use_past_ && !is_query_layer_) {
    if (!IsActGELU(func_graph, equiv, is_act_)) {
      return RET_ERROR;
    }
  }
  return RET_OK;
}

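// Creates and initializes the ops::EncoderLayer primitive from the checked pattern attributes.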
std::shared_ptr<ops::EncoderLayer> EncoderLayerFusion::CreatePrim(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                                  int64_t ffn_hidden_size, int64_t expert_num,
                                                                  int64_t expert_offset, float capacity_factor) const {
  auto encoder_layer_prim = std::make_shared<ops::EncoderLayer>();
  if (encoder_layer_prim == nullptr) {
    MS_LOG(ERROR) << "Build encoder layer primitive failed.";
    return nullptr;
  }
  int head_num = 0;
  int head_size = 0;
  float eps1 = 1e-5f;
  float eps2 = 1e-5f;
  float eps3 = 1e-5f;
  float scale = 1.0f;
  if (CheckPattern(func_graph, equiv, &head_num, &head_size, &eps1, &eps2, &eps3, &scale) != RET_OK) {
    return nullptr;
  }
  encoder_layer_prim->Init(head_num, head_size, eps1, eps2, eps3, ffn_hidden_size, expert_num, expert_offset,
                           capacity_factor, is_position_bias_, is_post_layernorm_, scale, act_type_, is_layernorm_,
                           is_use_past_, is_query_layer_, is_moe_, is_embedding_layer_);
  return encoder_layer_prim;
}

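// Derives shape-dependent attributes from the matched weights: ffn_hidden_size from the FFN weight
// shape and, for MoE, expert_num/expert_offset plus the capacity factor recovered from the matched
// expert-capacity constant.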
STATUS EncoderLayerFusion::InitAttributes(AnfNodePtr k_past, AnfNodePtr begin_expert_ids, AnfNodePtr weight_m,
                                          AnfNodePtr expert_capacity_node, int *ffn_hidden_size, int *expert_num,
                                          int *expert_offset, float *capacity_factor) const {
  auto base_shape_ptr = weight_m->Shape();
  MS_CHECK_TRUE_RET(base_shape_ptr != nullptr, RET_ERROR);
  auto input_shape_ptr = base_shape_ptr->cast<abstract::ShapePtr>();
  MS_CHECK_TRUE_RET(input_shape_ptr != nullptr, RET_ERROR);
  auto input_shape = input_shape_ptr->shape();
  MS_CHECK_TRUE_RET(input_shape.size() >= C2NUM, RET_ERROR);
  if (is_moe_) {
    // MoE FFN weights are stacked per expert: [expert_num, hidden_size, ffn_hidden_size].
    MS_CHECK_TRUE_RET(input_shape.size() > C2NUM, RET_ERROR);
    auto begin_expert_ids_node = begin_expert_ids->cast<ValueNodePtr>();
    MS_CHECK_TRUE_RET(begin_expert_ids_node != nullptr, RET_ERROR);
    *expert_num = static_cast<int>(input_shape[0]);
    *expert_offset = CastToInt(begin_expert_ids_node->value())[0];
    MS_CHECK_TRUE_RET(k_past != nullptr, RET_ERROR);
    auto base_shape_k = k_past->Shape();
    MS_CHECK_TRUE_RET(base_shape_k != nullptr, RET_ERROR);
    auto k_shape_ptr = base_shape_k->cast<abstract::ShapePtr>();
    MS_CHECK_TRUE_RET(k_shape_ptr != nullptr, RET_ERROR);
    auto k_shape = k_shape_ptr->shape();
    MS_CHECK_TRUE_RET(k_shape.size() > C3NUM, RET_ERROR);
    int seq = static_cast<int>(k_shape[C3NUM]);
    auto expert_capacity_value_node = utils::cast<ValuePtr>(utils::cast<ValueNodePtr>(expert_capacity_node)->value());
    if (expert_capacity_value_node->isa<tensor::Tensor>()) {
      auto tensor = expert_capacity_value_node->cast<tensor::TensorPtr>();
      // The matched constant holds the per-expert capacity as float16; recover the capacity factor
      // by inverting capacity = factor * seq / expert_num.
      auto expert_capacity = *(reinterpret_cast<float16 *>(tensor->data().data()));
      float cast_expert_capacity = Float16::ToFloat32(expert_capacity);
      *capacity_factor = (cast_expert_capacity) * (*expert_num) / seq;
    }
    *ffn_hidden_size = static_cast<int>(input_shape[C2NUM]);
  } else {
    *ffn_hidden_size = static_cast<int>(input_shape[1]);
  }
  return RET_OK;
}

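// Builds the single fused EncoderLayer CNode that replaces the whole matched subgraph. The input
// order assembled below is significant: it must mirror the layout the EncoderLayer kernel expects
// for the active combination of flags.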
CNodePtr EncoderLayerFusion::CreateMaskedEncoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                                const AnfNodePtr &node, bool post_layernorm,
                                                                bool mask = true) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(node != nullptr);
  AnfNodePtr v_past, k_past, expert_ids, begin_expert_ids, expert_capacity_node, embedding_table,
    embedding_table_input, embedding_table_pos;
  auto input = utils::cast<AnfNodePtr>((*equiv)[input_]);
  if (is_moe_) {
    expert_ids = utils::cast<AnfNodePtr>((*equiv)[expert_ids_]);
    begin_expert_ids = utils::cast<AnfNodePtr>((*equiv)[begin_expert_ids_]);
    expert_capacity_node = utils::cast<AnfNodePtr>((*equiv)[expert_capacity_]);
  }
  if (is_use_past_) {
    k_past = utils::cast<AnfNodePtr>((*equiv)[k_past_]);
    v_past = utils::cast<AnfNodePtr>((*equiv)[v_past_]);
  }
  auto weight_qkv = utils::cast<AnfNodePtr>((*equiv)[weight_attn_qkv_]);
  auto weight_attn_o = utils::cast<AnfNodePtr>((*equiv)[weight_attn_o_]);
  auto weight_m = utils::cast<AnfNodePtr>((*equiv)[weight_m_]);
  auto weight_p = utils::cast<AnfNodePtr>((*equiv)[weight_p_]);
  auto gamma1 = utils::cast<AnfNodePtr>((*equiv)[gamma1_]);
  auto gamma2 = utils::cast<AnfNodePtr>((*equiv)[gamma2_]);
  AnfNodePtr input_mask = mask ? utils::cast<AnfNodePtr>((*equiv)[mask_]) : nullptr;
  int ffn_hidden_size = 0, expert_num = 1, expert_offset = 0;
  float capacity_factor = 0;
  if (InitAttributes(k_past, begin_expert_ids, weight_m, expert_capacity_node, &ffn_hidden_size, &expert_num,
                     &expert_offset, &capacity_factor) != RET_OK) {
    MS_LOG(ERROR) << "Init Attributes failed.";
    return nullptr;
  }
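  // Assemble the fused node's inputs. The branch layout below reflects the ordering the kernel
  // expects (inferred from the pattern definitions): position-bias (T5) layers first, then pre/post-
  // layernorm layouts with optional past-KV, query, MoE, and final-layernorm extras.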
  auto encoder_layer_prim = CreatePrim(func_graph, equiv, ffn_hidden_size, expert_num, expert_offset, capacity_factor);
  MS_CHECK_TRUE_RET(encoder_layer_prim != nullptr, nullptr);
  auto encoder_layer_prim_c = encoder_layer_prim->GetPrim();
  MS_CHECK_TRUE_RET(encoder_layer_prim_c != nullptr, nullptr);
  auto value_node = NewValueNode(encoder_layer_prim_c);
  std::vector<AnfNodePtr> new_node_inputs = {value_node, input};
  if (is_position_bias_) {
    auto position_bias = utils::cast<AnfNodePtr>((*equiv)[position_bias_]);
    new_node_inputs.insert(new_node_inputs.end(), {gamma1, weight_qkv});
    if (mask) new_node_inputs.push_back(input_mask);
    new_node_inputs.insert(new_node_inputs.end(), {position_bias, weight_attn_o, gamma2, weight_m, weight_p});
    if (is_layernorm_) {
      auto gamma3 = utils::cast<AnfNodePtr>((*equiv)[gamma3_]);
      new_node_inputs.push_back(gamma3);
    }
  } else {
    auto bias_attn_qkv = utils::cast<AnfNodePtr>((*equiv)[bias_attn_qkv_]);
    auto bias_attn_o = utils::cast<AnfNodePtr>((*equiv)[bias_attn_o_]);
    auto bias_m = utils::cast<AnfNodePtr>((*equiv)[bias_m_]);
    auto bias_p = utils::cast<AnfNodePtr>((*equiv)[bias_p_]);
    auto beta1 = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
    auto beta2 = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
    if (!is_post_layernorm_) {
      if (is_use_past_) new_node_inputs.insert(new_node_inputs.end(), {k_past, v_past});
      if (is_query_layer_) {
        auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
        auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_attn_q_]);
        new_node_inputs.insert(new_node_inputs.end(), {gamma1, beta1, input_q, weight_q, weight_qkv, bias_attn_qkv});
      } else {
        new_node_inputs.insert(new_node_inputs.end(), {gamma1, beta1, weight_qkv, bias_attn_qkv});
      }
      if (mask) new_node_inputs.push_back(input_mask);
      new_node_inputs.insert(new_node_inputs.end(), {weight_attn_o, bias_attn_o, gamma2, beta2});
      if (is_moe_) new_node_inputs.push_back(expert_ids);
      new_node_inputs.insert(new_node_inputs.end(), {weight_m, bias_m, weight_p, bias_p});
    } else {
      new_node_inputs.insert(new_node_inputs.end(), {weight_qkv, bias_attn_qkv});
      if (mask) new_node_inputs.push_back(input_mask);
      new_node_inputs.insert(new_node_inputs.end(), {weight_attn_o, bias_attn_o, gamma1, beta1, weight_m, bias_m,
                                                     weight_p, bias_p, gamma2, beta2});
    }
    if (is_layernorm_) {
      auto beta3 = utils::cast<AnfNodePtr>((*equiv)[beta3_]);
      auto gamma3 = utils::cast<AnfNodePtr>((*equiv)[gamma3_]);
      new_node_inputs.insert(new_node_inputs.end(), {gamma3, beta3});
    }
  }
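  // Query and use-past layers additionally consume trailing graph inputs via inputs.end()[-1..-3];
  // this is a positional assumption on the graph signature, flagged as a temporary solution below.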
  auto inputs = func_graph->get_inputs();
  MS_CHECK_TRUE_RET(inputs.size() > C2NUM, nullptr);
  if (is_query_layer_) {
    embedding_table_pos = utils::cast<AnfNodePtr>((*equiv)[embedding_table_pos_]);
    embedding_table = utils::cast<AnfNodePtr>((*equiv)[embedding_table_]);
    new_node_inputs.insert(new_node_inputs.end(), {embedding_table, embedding_table_pos, inputs.end()[-2],
                                                   inputs.end()[-3], inputs.end()[-1]});
  } else if (is_use_past_) {  // temporary solution
    if (is_embedding_layer_) {
      embedding_table_input = utils::cast<AnfNodePtr>((*equiv)[embedding_table_input_]);
      embedding_table_pos = utils::cast<AnfNodePtr>((*equiv)[embedding_table_pos_]);
      new_node_inputs.insert(new_node_inputs.end(), {embedding_table_input, embedding_table_pos});
    }
    new_node_inputs.insert(new_node_inputs.end(), {inputs.end()[-3], inputs.end()[-1]});
  }
  auto new_node = func_graph->NewCNode(new_node_inputs);
  MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
  auto old_node = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(old_node != nullptr, nullptr);
  MS_CHECK_TRUE_RET(old_node->abstract() != nullptr, nullptr);
  new_node->set_abstract(old_node->abstract()->Clone());
  new_node->set_fullname_with_scope(node->fullname_with_scope() + "/encoder_layer");
  return new_node;
}
}  // namespace mindspore::opt