1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/fusion/encoder_layer_fusion.h"
19 #include <functional>
20 #include <utility>
21 #include <vector>
22 #include <algorithm>
23 #include "mindspore/core/ops/sequence_ops.h"
24 #include "mindspore/core/ops/other_ops.h"
25 #include "mindspore/core/ops/nn_optimizer_ops.h"
26 #include "mindspore/core/ops/nn_ops.h"
27 #include "mindspore/core/ops/math_ops.h"
28 #include "mindspore/core/ops/lite_ops.h"
29 #include "mindspore/core/ops/comparison_ops.h"
30 #include "mindspore/core/ops/array_ops.h"
31 #include "mindspore/core/ops/framework_ops.h"
32 #include "tools/optimizer/common/gllo_utils.h"
33 #include "nnacl/op_base.h"
34 #include "ops/tuple_get_item.h"
35 #include "tools/common/tensor_util.h"
36 #include "ops/op_utils.h"
37
38 namespace mindspore::opt {
39 namespace {
40 const auto &p1 = std::placeholders::_1;
41 } // namespace
42
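// Allocates the pattern placeholders (Var/CondVar) shared by every encoder-layer
// pattern below. Returns false as soon as any allocation fails.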
bool EncoderLayerFusion::Init() const {
  input_ = std::make_shared<Var>("input");
  MS_CHECK_TRUE_RET(input_ != nullptr, false);
  expert_ids_ = std::make_shared<Var>("expert_ids");
  MS_CHECK_TRUE_RET(expert_ids_ != nullptr, false);
  expert_capacity_ = std::make_shared<Var>("expert_capacity_");
  MS_CHECK_TRUE_RET(expert_capacity_ != nullptr, false);
  begin_expert_ids_ = std::make_shared<Var>("begin_expert_ids_");
  MS_CHECK_TRUE_RET(begin_expert_ids_ != nullptr, false);
  beta1_ = std::make_shared<Var>("beta1");
  MS_CHECK_TRUE_RET(beta1_ != nullptr, false);
  gamma1_ = std::make_shared<Var>("gamma1");
  MS_CHECK_TRUE_RET(gamma1_ != nullptr, false);
  beta2_ = std::make_shared<Var>("beta2");
  MS_CHECK_TRUE_RET(beta2_ != nullptr, false);
  gamma2_ = std::make_shared<Var>("gamma2");
  MS_CHECK_TRUE_RET(gamma2_ != nullptr, false);
  beta3_ = std::make_shared<Var>("beta3");
  MS_CHECK_TRUE_RET(beta3_ != nullptr, false);
  gamma3_ = std::make_shared<Var>("gamma3");
  MS_CHECK_TRUE_RET(gamma3_ != nullptr, false);
  weight_attn_qkv_ = std::make_shared<Var>("weight_attn_qkv");
  MS_CHECK_TRUE_RET(weight_attn_qkv_ != nullptr, false);
  weight_attn_q_ = std::make_shared<Var>("weight_attn_q_");
  MS_CHECK_TRUE_RET(weight_attn_q_ != nullptr, false);
  weight_attn_o_ = std::make_shared<CondVar>(IsParamNode, "weight_attn_o");
  MS_CHECK_TRUE_RET(weight_attn_o_ != nullptr, false);
  weight_m_ = std::make_shared<CondVar>(IsParamNode, "weight_m");
  MS_CHECK_TRUE_RET(weight_m_ != nullptr, false);
  weight_p_ = std::make_shared<CondVar>(IsParamNode, "weight_p");
  MS_CHECK_TRUE_RET(weight_p_ != nullptr, false);
  bias_attn_qkv_ = std::make_shared<Var>("bias_attn_qkv");
  MS_CHECK_TRUE_RET(bias_attn_qkv_ != nullptr, false);
  bias_attn_o_ = std::make_shared<CondVar>(IsParamNode, "bias_attn_o");
  MS_CHECK_TRUE_RET(bias_attn_o_ != nullptr, false);
  bias_m_ = std::make_shared<CondVar>(IsParamNode, "bias_m");
  MS_CHECK_TRUE_RET(bias_m_ != nullptr, false);
  bias_p_ = std::make_shared<CondVar>(IsParamNode, "bias_p");
  MS_CHECK_TRUE_RET(bias_p_ != nullptr, false);
  mask_ = std::make_shared<Var>("mask");
  MS_CHECK_TRUE_RET(mask_ != nullptr, false);
  is_attention_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAttention), "is_attention");
  MS_CHECK_TRUE_RET(is_attention_ != nullptr, false);
  is_layernorm1_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm1");
  MS_CHECK_TRUE_RET(is_layernorm1_ != nullptr, false);
  is_layernorm2_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm2");
  MS_CHECK_TRUE_RET(is_layernorm2_ != nullptr, false);
  position_bias_ = std::make_shared<Var>("position_bias");
  MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
  is_act_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "activation");
  MS_CHECK_TRUE_RET(is_act_ != nullptr, false);
  eps1_ = std::make_shared<Var>("eps1_");
  MS_CHECK_TRUE_RET(eps1_ != nullptr, false);
  eps2_ = std::make_shared<Var>("eps2_");
  MS_CHECK_TRUE_RET(eps2_ != nullptr, false);
  eps3_ = std::make_shared<Var>("eps3_");
  MS_CHECK_TRUE_RET(eps3_ != nullptr, false);
  batch_valid_length_ = std::make_shared<Var>("batch_valid_length");
  MS_CHECK_TRUE_RET(batch_valid_length_ != nullptr, false);
  k_past_ = std::make_shared<Var>("k_past");
  MS_CHECK_TRUE_RET(k_past_ != nullptr, false);
  v_past_ = std::make_shared<Var>("v_past");
  MS_CHECK_TRUE_RET(v_past_ != nullptr, false);
  input_q_ = std::make_shared<Var>("input_q");
  MS_CHECK_TRUE_RET(input_q_ != nullptr, false);
  embedding_table_ = std::make_shared<Var>("embedding_table");
  MS_CHECK_TRUE_RET(embedding_table_ != nullptr, false);
  is_layernorm3_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm3");
  MS_CHECK_TRUE_RET(is_layernorm3_ != nullptr, false);
  position_ids_ = std::make_shared<Var>("position_ids_");
  MS_CHECK_TRUE_RET(position_ids_ != nullptr, false);
  embedding_table_input_ = std::make_shared<Var>("embedding_table_input_");
  MS_CHECK_TRUE_RET(embedding_table_input_ != nullptr, false);
  current_index_ = std::make_shared<Var>("current_index_");
  MS_CHECK_TRUE_RET(current_index_ != nullptr, false);
  embedding_table_pos_ = std::make_shared<Var>("embedding_table_pos_");
  MS_CHECK_TRUE_RET(embedding_table_pos_ != nullptr, false);
  return true;
}

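// Returns the sub-pattern feeding the attention inputs: a leading Reshape of input_,
// optionally followed by a LayerNorm (either the fused LayerNormFusion + TupleGetItem
// pair, or the decomposed form from DefineLayerNorm).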
VectorRef EncoderLayerFusion::getTuple(bool post_layernorm, bool layernorm_fusion = false,
                                       bool is_position_bias = false) const {
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1-reshape");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, input_, var1});
  if (post_layernorm && !is_position_bias) {
    return reshape1;
  }
  if (!layernorm_fusion) {
    return DefineLayerNorm(is_position_bias, reshape1, gamma1_, beta1_, eps1_);
  }
  auto layer_norm = VectorRef({is_layernorm1_, reshape1, gamma1_, beta1_});
  auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple = std::make_shared<Var>("var_tuple");
  auto tuple = VectorRef({is_tuple, layer_norm, var_tuple});
  return tuple;
}

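// Matches a decomposed (un-fused) LayerNorm subgraph:
//   y = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps) * gamma + beta
// With is_position_bias (T5-style RMSNorm) the mean subtraction and beta are absent:
//   y = x / sqrt(mean(x^2) + eps) * gamma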
VectorRef EncoderLayerFusion::DefineLayerNorm(bool is_position_bias, BaseRef input, VarPtr gamma, VarPtr beta,
                                              VarPtr eps) const {
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce");
  MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
  auto reduce1 = VectorRef({is_reduce, input, var1});
  auto is_sub = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSubFusion), "sub-f");
  MS_CHECK_TRUE_RET(is_sub != nullptr, {});
  auto sub = VectorRef({is_sub, input, reduce1});
  auto is_sqr = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSquare), "sqr");
  MS_CHECK_TRUE_RET(is_sqr != nullptr, {});
  auto sqr = (is_position_bias) ? VectorRef({is_sqr, input}) : VectorRef({is_sqr, sub});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto is_reduce2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce2");
  MS_CHECK_TRUE_RET(is_reduce2 != nullptr, {});
  auto reduce2 = VectorRef({is_reduce2, sqr, var2});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is-add");
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto add = VectorRef({is_add, reduce2, eps});
  auto is_sqrt = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSqrt), "sqrt");
  MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
  auto sqrt = VectorRef({is_sqrt, add});
  auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "real-div");
  MS_CHECK_TRUE_RET(is_div != nullptr, {});
  auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul");
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  if (is_position_bias) {
    auto real_div = VectorRef({is_div, input, sqrt});
    auto mul = VectorRef({is_mul, real_div, gamma});
    return mul;
  }
  auto real_div = VectorRef({is_div, sub, sqrt});
  auto mul = VectorRef({is_mul, real_div, gamma});
  auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is-add2");
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  auto add2 = VectorRef({is_add2, mul, beta});
  return add2;
}

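// Matches the init_reset sub-graph that appears to clear or keep the key/value caches:
// k_past_/v_past_ are multiplied by a cast of the init_reset_ flag and written back
// with Assign. key_reset/value_reset let the caller match only the key or value branch.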
VectorRef EncoderLayerFusion::DefinePatternInitReset(VectorRef input, bool value_reset, bool key_reset) const {
  auto is_cast = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCast), "cast_init");
  auto var_cast = std::make_shared<Var>("var_cast");
  auto cast = VectorRef({is_cast, init_reset_, var_cast});
  auto is_mul_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul_k");
  auto mul_k = VectorRef({is_mul_k, k_past_, cast});
  auto is_assign_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAssign), "assign_k");
  auto var_assign_k = std::make_shared<Var>("var_assign_k");
  auto assign_k = VectorRef({is_assign_k, k_past_, mul_k, var_assign_k});
  if (key_reset) return assign_k;
  auto is_depend_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend_k");
  auto depend_k = VectorRef({is_depend_k, input, assign_k});
  auto is_mul_v = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul_v");
  auto mul_v = VectorRef({is_mul_v, v_past_, cast});
  auto is_assign_v = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAssign), "assign_v");
  auto var_assign_v = std::make_shared<Var>("var_assign_v");
  auto assign_v = VectorRef({is_assign_v, v_past_, mul_v, var_assign_v});
  if (value_reset) return assign_v;
  auto is_depend_kv = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend_kv");
  auto depend_kv = VectorRef({is_depend_kv, depend_k, assign_v});
  return depend_kv;
}

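// Matches the batch_valid_length masking applied to a cache tensor:
// input * ExpandDims(Cast(range <= Reshape(batch_valid_length_)), axis).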
BaseRef EncoderLayerFusion::DefineBatchValidLength(const BaseRef &input) const {
  auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape");
  MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
  auto var = std::make_shared<Var>("var");
  MS_CHECK_TRUE_RET(var != nullptr, {});
  auto reshape = VectorRef({is_reshape, batch_valid_length_, var});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto is_less = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLessEqual), "is_less_eq");
  MS_CHECK_TRUE_RET(is_less != nullptr, {});
  auto less = VectorRef({is_less, var2, reshape});
  auto var3 = std::make_shared<Var>("var3");
  MS_CHECK_TRUE_RET(var3 != nullptr, {});
  auto is_cast = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCast), "is_cast");
  MS_CHECK_TRUE_RET(is_cast != nullptr, {});
  auto cast = VectorRef({is_cast, less, var3});
  auto var4 = std::make_shared<Var>("var4");
  MS_CHECK_TRUE_RET(var4 != nullptr, {});
  auto is_expand_dims = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims), "is_expand_dims");
  MS_CHECK_TRUE_RET(is_expand_dims != nullptr, {});
  auto expand_dims = VectorRef({is_expand_dims, cast, var4});
  auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul");
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  auto mul = VectorRef({is_mul, input, expand_dims});
  return mul;
}

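// Matches the capacity-aware top-k router of the MoE block: OneHot over the expert
// ids, a CumSum-based position of each token inside its expert, a Less test against
// expert_capacity_, and the Mul/ReduceFusion/ExpandDims/OneHot chain that builds the
// final dispatch tensor.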
VectorRef EncoderLayerFusion::DefinePatternMoETopKRouter(VectorRef input) const {
  auto var_onehot1 = std::make_shared<Var>("var_onehot1");
  auto var_onehot2 = std::make_shared<Var>("var_onehot2");
  auto var_onehot3 = std::make_shared<Var>("var_onehot3");
  auto var_onehot4 = std::make_shared<Var>("var_onehot4");
  auto is_onehot = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimOneHot), "is_onehot");
  MS_CHECK_TRUE_RET(is_onehot != nullptr, {});
  auto onehot = VectorRef({is_onehot, input, var_onehot1, var_onehot2, var_onehot3});
  auto is_cumsum = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCumSum), "is_cumsum");
  MS_CHECK_TRUE_RET(is_cumsum != nullptr, {});
  auto var_cumsum = std::make_shared<Var>("var_cumsum");
  MS_CHECK_TRUE_RET(var_cumsum != nullptr, {});
  auto cumsum = VectorRef({is_cumsum, onehot, var_cumsum});
  auto is_mul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul1");
  MS_CHECK_TRUE_RET(is_mul1 != nullptr, {});
  auto mul_fusion1 = VectorRef({is_mul1, cumsum, onehot});
  auto is_less = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLess), "is_less");
  MS_CHECK_TRUE_RET(is_less != nullptr, {});
  auto less = VectorRef({is_less, mul_fusion1, expert_capacity_});
  auto is_cast3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCast), "is_cast3");
  MS_CHECK_TRUE_RET(is_cast3 != nullptr, {});
  auto var_cast3 = std::make_shared<Var>("var_cast3");
  auto cast3 = VectorRef({is_cast3, less, var_cast3});
  auto is_mul4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul4");
  MS_CHECK_TRUE_RET(is_mul4 != nullptr, {});
  auto mul_fusion4 = VectorRef({is_mul4, cast3, onehot});
  auto is_reduce5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "is_reduce5");
  MS_CHECK_TRUE_RET(is_reduce5 != nullptr, {});
  auto var_reduce5 = std::make_shared<Var>("var_reduce5");
  auto reduce5 = VectorRef({is_reduce5, mul_fusion4, var_reduce5});
  auto is_expand_dims1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims), "is_expand_dims1");
  MS_CHECK_TRUE_RET(is_expand_dims1 != nullptr, {});
  auto var_expand_dims = std::make_shared<Var>("var_expand_dims1");
  auto expand_dims1 = VectorRef({is_expand_dims1, reduce5, var_expand_dims});
  auto is_mul7 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul7");
  MS_CHECK_TRUE_RET(is_mul7 != nullptr, {});
  auto mul_fusion7 = VectorRef({is_mul7, expand_dims1, onehot});
  auto is_onehot2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimOneHot), "is_onehot2");
  MS_CHECK_TRUE_RET(is_onehot2 != nullptr, {});
  auto is_cast2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimCast), "is_cast2");
  MS_CHECK_TRUE_RET(is_cast2 != nullptr, {});
  auto var_cast2 = std::make_shared<Var>("var_cast2");
  auto cast2 = VectorRef({is_cast2, mul_fusion1, var_cast2});
  auto onehot2 = VectorRef({is_onehot2, cast2, var_onehot4, var_onehot2, var_onehot3});
  auto is_expand_dims2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimExpandDims), "is_expand_dims2");
  MS_CHECK_TRUE_RET(is_expand_dims2 != nullptr, {});
  auto expand_dims2 = VectorRef({is_expand_dims2, mul_fusion7, var_expand_dims});
  auto is_mul8 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "is_mul8");
  MS_CHECK_TRUE_RET(is_mul8 != nullptr, {});
  auto mul_fusion8 = VectorRef({is_mul8, expand_dims2, onehot2});
  return mul_fusion8;
}

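// Matches the router front-end: expert_ids_ goes through an activation and a
// StridedSlice starting at begin_expert_ids_, a Depend ties it to the layer-norm
// output, and the result is reshaped before entering the top-k router above.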
VectorRef EncoderLayerFusion::DefinePatternMoERouter(VectorRef input_layernorm) const {
  auto is_relu = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "is_relu");
  auto relu = VectorRef({is_relu, expert_ids_});
  auto is_stride_slice = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimStridedSlice), "is_stride_slice");
  MS_CHECK_TRUE_RET(is_stride_slice != nullptr, {});
  auto var_stride_slice2 = std::make_shared<Var>("var_stride_slice2");
  MS_CHECK_TRUE_RET(var_stride_slice2 != nullptr, {});
  auto var_stride_slice3 = std::make_shared<Var>("var_stride_slice3");
  MS_CHECK_TRUE_RET(var_stride_slice3 != nullptr, {});
  auto stride_slice = VectorRef({is_stride_slice, relu, begin_expert_ids_, var_stride_slice2, var_stride_slice3});
  auto is_depend = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend");
  auto depend = VectorRef({is_depend, stride_slice, input_layernorm});
  auto is_reshape_router = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "router-reshape");
  MS_CHECK_TRUE_RET(is_reshape_router != nullptr, {});
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto router_reshape = VectorRef({is_reshape_router, depend, var1});
  return DefinePatternMoETopKRouter(router_reshape);
}

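// Matches a single expert FFN: MatMul plus bias add, GeLU or FastGeLU, a second
// MatMul, an AllReduce for the model-parallel case, and the final bias add.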
VectorRef EncoderLayerFusion::DefinePatternMoEFfn(VectorRef input_reshape, bool gelu = false) const {
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  auto matmul1 = VectorRef({is_matmul1, input_reshape, weight_m_});
  auto is_add1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add1");
  MS_CHECK_TRUE_RET(is_add1 != nullptr, {});
  auto add = VectorRef({is_add1, matmul1, bias_m_});
  auto is_act = (gelu) ? std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "is_Gelu")
                       : std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimFastGeLU), "is_FastGelu");
  auto act = VectorRef({is_act, add});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  auto matmul2 = VectorRef({is_matmul2, act, weight_p_});
  auto is_all_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAllReduce), "is_all_reduce");
  auto all_reduce = VectorRef({is_all_reduce, matmul2});
  auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add2");
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  auto add2 = VectorRef({is_add2, all_reduce, bias_p_});
  return add2;
}

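// Matches the complete MoE block: the router output is reshaped and fed into the
// dispatch MatMul, the expert FFN is applied, then a combine MatMul maps tokens back.
// multi_batch inserts an extra Transpose after the expert FFN output.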
VectorRef EncoderLayerFusion::DefinePatternMoE(VectorRef input_layernorm, bool multi_batch = false,
                                               bool gelu = false) const {
  auto router_output = DefinePatternMoERouter(input_layernorm);
  auto is_reshape10 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape10");
  MS_CHECK_TRUE_RET(is_reshape10 != nullptr, {});
  auto var_reshape10 = std::make_shared<Var>("var_reshape10");
  MS_CHECK_TRUE_RET(var_reshape10 != nullptr, {});
  auto reshape10 = VectorRef({is_reshape10, router_output, var_reshape10});
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
  auto matmul1 = VectorRef({is_matmul1, input_layernorm, reshape10});
  auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape3");
  MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
  auto var_reshape3 = std::make_shared<Var>("var_reshape3");
  MS_CHECK_TRUE_RET(var_reshape3 != nullptr, {});
  auto reshape3 = VectorRef({is_reshape3, matmul1, var_reshape3});
  auto is_transpose1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "is_transpose1");
  MS_CHECK_TRUE_RET(is_transpose1 != nullptr, {});
  auto var_transpose1 = std::make_shared<Var>("var_transpose1");
  MS_CHECK_TRUE_RET(var_transpose1 != nullptr, {});
  auto transpose1 = VectorRef({is_transpose1, reshape3, var_transpose1});
  auto is_reshape4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape4");
  MS_CHECK_TRUE_RET(is_reshape4 != nullptr, {});
  auto var_reshape4 = std::make_shared<Var>("var_reshape4");
  MS_CHECK_TRUE_RET(var_reshape4 != nullptr, {});
  auto reshape4 = VectorRef({is_reshape4, transpose1, var_reshape4});
  auto ffn_output = DefinePatternMoEFfn(reshape4, gelu);
  auto is_reshape6 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape6");
  MS_CHECK_TRUE_RET(is_reshape6 != nullptr, {});
  auto var_reshape6 = std::make_shared<Var>("var_reshape6");
  MS_CHECK_TRUE_RET(var_reshape6 != nullptr, {});
  auto reshape6 = VectorRef({is_reshape6, ffn_output, var_reshape6});
  VectorRef transpose2;
  if (multi_batch) {
    auto is_transpose2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTranspose), "is_transpose2");
    MS_CHECK_TRUE_RET(is_transpose2 != nullptr, {});
    auto var_transpose2 = std::make_shared<Var>("var_transpose2");
    MS_CHECK_TRUE_RET(var_transpose2 != nullptr, {});
    transpose2 = VectorRef({is_transpose2, reshape6, var_transpose2});
  } else {
    transpose2 = reshape6;
  }
  auto is_reshape8 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "is_reshape8");
  MS_CHECK_TRUE_RET(is_reshape8 != nullptr, {});
  auto var_reshape8 = std::make_shared<Var>("var_reshape8");
  MS_CHECK_TRUE_RET(var_reshape8 != nullptr, {});
  auto reshape8 = VectorRef({is_reshape8, transpose2, var_reshape8});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  auto matmul2 = VectorRef({is_matmul2, reshape10, reshape8});
  return matmul2;
}

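// Matches the key/value cache update after attention: the new key/value tensors
// (TupleGetItem outputs of the attention node) are masked by batch_valid_length and
// written back to k_past_/v_past_ with Assign; Depend nodes order the writes before
// the FFN output is consumed.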
VectorRef EncoderLayerFusion::DefineDependKV(VectorRef input, VectorRef depend_v_input, bool moe = false) const {
  auto is_tuple_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item_k");
  auto var_tuple_k = std::make_shared<Var>("var_tuple_k");
  auto tuple_k = VectorRef({is_tuple_k, input, var_tuple_k});
  auto is_assign_k = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAssign), "assign_k_fuse");
  auto var_assign_k = std::make_shared<Var>("var_assign_k");
  auto assign_k = VectorRef({is_assign_k, k_past_, DefineBatchValidLength(tuple_k), var_assign_k});
  auto is_tuple_v = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item_v");
  auto var_tuple_v = std::make_shared<Var>("var_tuple_v");
  auto tuple_v = VectorRef({is_tuple_v, input, var_tuple_v});
  auto is_assign_v = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAssign), "assign_v_fuse");
  auto var_assign_v = std::make_shared<Var>("var_assign_v");
  auto assign_v = VectorRef({is_assign_v, v_past_, DefineBatchValidLength(tuple_v), var_assign_v});
  auto is_depend_v_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend_v");
  auto depend_v_mul = VectorRef({is_depend_v_mul, depend_v_input, assign_k});
  auto is_depend_kv_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend_kv");
  return VectorRef({is_depend_kv_mul, depend_v_mul, assign_v});
}

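// Matches the dense (non-MoE) FFN of the PanGu-sigma layer: Reshape, MatMul with bias,
// GeLU or FastGeLU, a second MatMul (bias folded in, or added after AllReduce in the
// distributed case), and a trailing Reshape.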
VectorRef EncoderLayerFusion::DefinePatternSigmaFfn(BaseRef input, bool gelu = false, bool distributed = false) const {
  auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2");
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto reshape2 = VectorRef({is_reshape2, input, var2});
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  auto matmul1 = VectorRef({is_matmul1, reshape2, weight_m_, bias_m_});
  auto is_act = (gelu) ? std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "is_Gelu")
                       : std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimFastGeLU), "is_FastGelu");
  MS_CHECK_TRUE_RET(is_act != nullptr, {});
  auto act = VectorRef({is_act, matmul1});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  auto matmul2 = VectorRef({is_matmul2, act, weight_p_});
  if (!distributed) matmul2.push_back(bias_p_);
  auto is_all_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAllReduce), "is_all_reduce");
  auto all_reduce = VectorRef({is_all_reduce, matmul2});
  auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add2");
  auto add2 = VectorRef({is_add2, all_reduce, bias_p_});
  auto is_reshape5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder5");
  MS_CHECK_TRUE_RET(is_reshape5 != nullptr, {});
  auto var5 = std::make_shared<Var>("var5");
  MS_CHECK_TRUE_RET(var5 != nullptr, {});
  auto reshape5 = (distributed) ? VectorRef({is_reshape5, add2, var5}) : VectorRef({is_reshape5, matmul2, var5});
  return reshape5;
}

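// Full PanGu-sigma encoder layer (pre-layernorm): decomposed LayerNorm, attention
// with KV cache, residual add, second LayerNorm, dense or MoE FFN, the cache-update
// Depend chain, and the final residual add. query_layer adds the positional-embedding
// query path plus the Gather/MatMul head; first_encoder prepends the embedding lookup.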
VectorRef EncoderLayerFusion::DefinePatternEncoderSigma(bool moe = false, bool use_past = true,
                                                        bool distributed = false, bool is_layer_norm = false,
                                                        bool query_layer = false, bool multi_batch = false,
                                                        bool first_encoder = false, bool gelu = false) const {
  VectorRef input, q_input;
  if (first_encoder) {
    input = DefineFirstEncoder(false);
  }
  auto is_reshape = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
  MS_CHECK_TRUE_RET(is_reshape != nullptr, {});
  auto var = std::make_shared<Var>("var");
  MS_CHECK_TRUE_RET(var != nullptr, {});
  if (query_layer) {
    auto var_gather_emb = std::make_shared<Var>("var_gather_emb");
    MS_CHECK_TRUE_RET(var_gather_emb != nullptr, {});
    auto is_gather_emb = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimGather), "is_gather_emb");
    MS_CHECK_TRUE_RET(is_gather_emb != nullptr, {});
    auto gather_emb = VectorRef({is_gather_emb, embedding_table_pos_, input_q_, var_gather_emb});
    auto is_reshape_emb = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-emb");
    MS_CHECK_TRUE_RET(is_reshape_emb != nullptr, {});
    auto var_reshape_emb = std::make_shared<Var>("var_reshape_emb");
    MS_CHECK_TRUE_RET(var_reshape_emb != nullptr, {});
    q_input = VectorRef({is_reshape_emb, gather_emb, var_reshape_emb});
  }
  auto reshape = (first_encoder) ? VectorRef({is_reshape, DefineLayerNorm(false, input, gamma1_, beta1_, eps1_), var})
                                 : VectorRef({is_reshape, DefineLayerNorm(false, input_, gamma1_, beta1_, eps1_), var});
  auto attention = (query_layer) ? VectorRef({is_attention_, q_input, reshape, reshape, weight_attn_q_,
                                              weight_attn_qkv_, weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_})
                                 : VectorRef({is_attention_, reshape, reshape, reshape, weight_attn_qkv_,
                                              weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_});
  auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple = std::make_shared<Var>("var_tuple");
  auto tuple = VectorRef({is_tuple, attention, var_tuple});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
  auto add = (first_encoder) ? VectorRef({is_add, input, tuple}) : VectorRef({is_add, input_, tuple});
  auto layer_norm2 = DefineLayerNorm(false, add, gamma2_, beta2_, eps2_);
  auto ffn_output =
    (moe) ? DefinePatternMoE(layer_norm2, multi_batch, gelu) : DefinePatternSigmaFfn(layer_norm2, gelu, distributed);
  auto depend_kv_mul = DefineDependKV(attention, ffn_output, moe);
  auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
  auto add3 = VectorRef({is_add3, add, depend_kv_mul});
  if (is_layer_norm) return DefineLayerNorm(false, add3, gamma3_, beta3_, eps3_);
  if (query_layer) {
    auto is_reshape7 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder7");
    MS_CHECK_TRUE_RET(is_reshape7 != nullptr, {});
    auto var7 = std::make_shared<Var>("var7");
    MS_CHECK_TRUE_RET(var7 != nullptr, {});
    auto reshape7 = VectorRef({is_reshape7, add3, var7});
    add3 = reshape7;
    auto is_gather = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimGather), "is_gather");
    MS_CHECK_TRUE_RET(is_gather != nullptr, {});
    auto var_gather2 = std::make_shared<Var>("var_gather2");
    auto gather = VectorRef({is_gather, add3, current_index_, var_gather2});
    auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
    MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
    auto matmul3 = VectorRef({is_matmul3, gather, embedding_table_});
    return matmul3;
  }
  return add3;
}

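// Full PanGu-alpha encoder layer: fused LayerNormFusion nodes, attention with
// use_past Depend ordering on k_past_/v_past_, an FFN with optional AllReduce, and
// the cache-update chain from DefineDependKV.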
VectorRef EncoderLayerFusion::DefinePatternEncoderAlpha(bool moe = false, bool distributed = false,
                                                        bool is_layer_norm = false, bool query_layer = false,
                                                        bool use_past = false) const {
  VectorRef add2;
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1-reshape");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, input_, var1});
  VectorRef layer_norm = (query_layer) ? VectorRef({is_layernorm1_, input_, gamma1_, beta1_})
                                       : VectorRef({is_layernorm1_, reshape1, gamma1_, beta1_});
  auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  auto var_tuple = std::make_shared<Var>("var_tuple");
  auto tuple = VectorRef({is_tuple, layer_norm, var_tuple});
  auto is_depend = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend");
  auto depend = VectorRef({is_depend, tuple, k_past_});
  auto is_depend2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend2");
  auto depend2 = VectorRef({is_depend2, depend, v_past_});
  auto attention = (query_layer) ? VectorRef({is_attention_, input_q_, DefinePatternInitReset(tuple, false, false),
                                              DefinePatternInitReset(tuple, false, false), weight_attn_q_,
                                              weight_attn_qkv_, weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_})
                                 : VectorRef({is_attention_, depend2, depend2, depend2, weight_attn_qkv_,
                                              weight_attn_o_, bias_attn_qkv_, bias_attn_o_, mask_});
  auto is_tuple1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item1");
  auto var_tuple1 = std::make_shared<Var>("var_tuple1");
  auto tuple1 = VectorRef({is_tuple1, attention, var_tuple1});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
  auto add = (query_layer) ? VectorRef({is_add, input_, tuple1}) : VectorRef({is_add, reshape1, tuple1});
  auto layer_norm2 = VectorRef({is_layernorm2_, add, gamma2_, beta2_});
  auto is_tuple2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item2");
  auto var_tuple2 = std::make_shared<Var>("var_tuple2");
  auto tuple2 = VectorRef({is_tuple2, layer_norm2, var_tuple2});
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  auto matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
  auto is_act = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "is_FastGelu");
  MS_CHECK_TRUE_RET(is_act != nullptr, {});
  auto act = VectorRef({is_act, matmul1});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  auto matmul2 =
    (distributed) ? VectorRef({is_matmul2, act, weight_p_}) : VectorRef({is_matmul2, act, weight_p_, bias_p_});
  if (distributed) {
    auto is_all_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAllReduce), "is_all_reduce");
    auto all_reduce = VectorRef({is_all_reduce, matmul2});
    auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add2");
    add2 = VectorRef({is_add2, all_reduce, bias_p_});
  }
  if (use_past && !query_layer) {
    auto is_depend3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend3");
    auto depend3 = VectorRef({is_depend3, v_past_, v_past_});
    auto is_depend4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend4");
    auto depend4 =
      (distributed) ? VectorRef({is_depend4, add2, depend3}) : VectorRef({is_depend4, matmul2, depend3});
    auto is_depend5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend5");
    auto depend5 = VectorRef({is_depend5, k_past_, k_past_});
    auto is_depend6 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend6");
    matmul2 = VectorRef({is_depend6, depend4, depend5});
  }
  auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2");
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto reshape2 = VectorRef({is_reshape2, add, var2});
  auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder3");
  MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
  auto var3 = std::make_shared<Var>("var3");
  MS_CHECK_TRUE_RET(var3 != nullptr, {});
  auto reshape3 = VectorRef({is_reshape3, matmul2, var3});
  auto depend_kv_mul =
    (distributed) ? DefineDependKV(attention, add2, true) : DefineDependKV(attention, matmul2, false);
  auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
  auto add3 = (query_layer) ? VectorRef({is_add3, add, depend_kv_mul}) : VectorRef({is_add3, reshape2, reshape3});
  if (is_layer_norm) {
    auto is_reshape_norm = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-norm");
    MS_CHECK_TRUE_RET(is_reshape_norm != nullptr, {});
    auto var_norm = std::make_shared<Var>("var_norm");
    MS_CHECK_TRUE_RET(var_norm != nullptr, {});
    auto reshape_norm = VectorRef({is_reshape_norm, add3, var_norm});
    auto layer_norm3 = VectorRef({is_layernorm3_, reshape_norm, gamma3_, beta3_});
    return layer_norm3;
  }
  if (query_layer) {
    auto is_matmul3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul3");
    MS_CHECK_TRUE_RET(is_matmul3 != nullptr, {});
    auto matmul3 = VectorRef({is_matmul3, add3, embedding_table_});
    return matmul3;
  }
  return add3;
}

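// Matches the embedding front-end fused into the first encoder layer:
// a word-embedding Gather plus a position-embedding Gather, summed.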
VectorRef EncoderLayerFusion::DefineFirstEncoder(bool distributed = false) const {
  auto is_depend = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimDepend), "depend");
  auto depend = VectorRef({is_depend, input_, expert_ids_});
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto is_gather = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimGather), "is_gather");
  MS_CHECK_TRUE_RET(is_gather != nullptr, {});
  auto gather = (distributed) ? VectorRef({is_gather, embedding_table_input_, depend, var1})
                              : VectorRef({is_gather, embedding_table_input_, input_, var1});
  auto is_gather2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimGather), "is_gather2");
  MS_CHECK_TRUE_RET(is_gather2 != nullptr, {});
  auto gather2 = VectorRef({is_gather2, embedding_table_pos_, position_ids_, var1});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto add = VectorRef({is_add, gather, gather2});
  return add;
}

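// Generic (BERT/T5-style) encoder layer. post_layernorm selects post- vs
// pre-layernorm residual placement, layernorm_fusion selects the fused
// LayerNormFusion form vs the decomposed one, is_position_bias enables the T5
// relative-bias variant without attention biases, and is_layer_norm appends a
// final layer norm.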
VectorRef EncoderLayerFusion::DefinePatternEncoderLayer(bool post_layernorm = true, bool layernorm_fusion = false,
                                                        bool is_position_bias = false, bool mask = true,
                                                        bool is_layer_norm = false) const {
  VectorRef tuple, tuple2, tuple3, reshape2, matmul1, inputs, layer_norm2;
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, input_, var1});
  if (!is_position_bias) {
    inputs = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_, weight_attn_o_,
                        bias_attn_qkv_, bias_attn_o_});
  } else {
    inputs = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_, weight_attn_o_,
                        position_bias_});
  }
  if (mask) inputs.push_back(mask_);
  auto attention = VectorRef(inputs);
  if (!is_position_bias) {
    auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
    auto var_tuple = std::make_shared<Var>("var_tuple");
    tuple = VectorRef({is_tuple, attention, var_tuple});
  } else {
    tuple = attention;
  }
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add");
  auto add = (is_position_bias && post_layernorm)
               ? VectorRef({is_add, getTuple(post_layernorm, layernorm_fusion, is_position_bias), tuple})
               : VectorRef({is_add, reshape1, tuple});
  auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder2");
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  if (layernorm_fusion) {
    layer_norm2 = VectorRef({is_layernorm2_, add, gamma2_, beta2_});
    auto is_tuple2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item2");
    auto var_tuple2 = std::make_shared<Var>("var_tuple2");
    tuple2 = VectorRef({is_tuple2, layer_norm2, var_tuple2});
  } else {
    tuple2 = DefineLayerNorm(is_position_bias, add, gamma2_, beta2_, eps2_);
  }
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
  if (is_position_bias) {
    reshape2 = (post_layernorm) ? VectorRef({is_reshape2, tuple2, var2}) : VectorRef({is_reshape2, add, var2});
    matmul1 = VectorRef({is_matmul1, tuple2, weight_m_});
  } else if (post_layernorm || !layernorm_fusion) {
    reshape2 = VectorRef({is_reshape2, tuple2, var2});
    matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
  } else {
    reshape2 = VectorRef({is_reshape2, add, var2});
    matmul1 = VectorRef({is_matmul1, tuple2, weight_m_, bias_m_});
  }
  auto act = VectorRef({is_act_, matmul1});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  auto matmul2 =
    (is_position_bias) ? VectorRef({is_matmul2, matmul1, weight_p_}) : VectorRef({is_matmul2, act, weight_p_, bias_p_});
  auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder3");
  MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
  auto var3 = std::make_shared<Var>("var3");
  MS_CHECK_TRUE_RET(var3 != nullptr, {});
  auto reshape3 = VectorRef({is_reshape3, matmul2, var3});
  auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
  auto add3 = VectorRef({is_add3, reshape2, reshape3});
  if (is_layer_norm) return DefineLayerNorm(is_position_bias, add3, gamma3_, beta3_, eps3_);
  if (!post_layernorm || !layernorm_fusion) return add3;
  auto is_reshape4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder4");
  MS_CHECK_TRUE_RET(is_reshape4 != nullptr, {});
  auto var4 = std::make_shared<Var>("var4");
  MS_CHECK_TRUE_RET(var4 != nullptr, {});
  auto reshape4 = VectorRef({is_reshape4, add3, var4});
  if (layernorm_fusion) {
    auto layer_norm = VectorRef({is_layernorm1_, reshape4, gamma1_, beta1_});
    auto is_tuple3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item3");
    auto var_tuple3 = std::make_shared<Var>("var_tuple3");
    tuple3 = VectorRef({is_tuple3, layer_norm, var_tuple3});
  } else {
    tuple3 = DefineLayerNorm(is_position_bias, reshape4, gamma1_, beta1_, eps1_);
  }
  auto is_reshape5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-encoder5");
  MS_CHECK_TRUE_RET(is_reshape5 != nullptr, {});
  auto var5 = std::make_shared<Var>("var5");
  MS_CHECK_TRUE_RET(var5 != nullptr, {});
  auto reshape5 = VectorRef({is_reshape5, tuple3, var5});
  return reshape5;
}

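// Registers every supported encoder-layer variant under a pattern name; Process()
// later decodes the fusion flags back from the matched name.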
std::unordered_map<std::string, VectorRef> EncoderLayerFusion::DefinePatterns() const {
  std::unordered_map<std::string, VectorRef> patterns;
  if (!Init()) {
    MS_LOG(ERROR) << "initial member failed.";
    return patterns;
  }
  if (embedding_layer_) {
    patterns[kPatternSigmaEmbedding] = DefinePatternEncoderSigma(false, true, false, false, false, false, true);
    patterns[kPatternSigmaEmbeddingDistributed] =
      DefinePatternEncoderSigma(false, true, true, false, false, false, true);
    patterns[kPatternSigmaDistributedEmbedding] =
      DefinePatternEncoderSigma(false, true, true, false, false, false, true);
    patterns[kPatternSigmaDistributedEmbeddingMB] =
      DefinePatternEncoderSigma(false, true, true, false, false, true, true);
  } else {
    // pangu sigma
    patterns[kPatternSigmaDistributedEmbeddingMBGELU] =
      DefinePatternEncoderSigma(false, true, true, false, false, true, true, true);
    patterns[kPatternSigmaDistributed] = DefinePatternEncoderSigma(false, true, true, false, false);
    patterns[kPatternSigmaMoeDistributed] = DefinePatternEncoderSigma(true, true, true, false, false);
    patterns[kPatternSigmaMoeWithLastLayerNormDistributed] = DefinePatternEncoderSigma(true, true, true, true, false);
    patterns[kPatternSigmaQueryLayerDistributed] = DefinePatternEncoderSigma(false, true, true, false, true);
    patterns[kPatternSigmaQueryLayerDistributedMoe] = DefinePatternEncoderSigma(true, true, true, false, true);
    patterns[kPatternSigma] = DefinePatternEncoderSigma(false, true, false, false, false);
    patterns[kPatternSigmaWithLastLayerNorm] = DefinePatternEncoderSigma(false, true, false, true, false);
    patterns[kPatternSigmaQuery] = DefinePatternEncoderSigma(false, true, false, false, true);
    patterns[kPatternSigmaMoe] = DefinePatternEncoderSigma(true, true, false, false, false);
    patterns[kPatternSigmaWithLastLayerNormDistributed] = DefinePatternEncoderSigma(false, true, true, true, false);
    patterns[kPatternSigmaMoeWithLastLayerNorm] = DefinePatternEncoderSigma(true, true, true, true, false);
    patterns[kPatternSigmaQueryLayerMoe] = DefinePatternEncoderSigma(true, true, false, false, true);
    // multi-batch
    // fast-gelu
    patterns[kPatternSigmaMoeWithLastLayerNormDistributedMB] =
      DefinePatternEncoderSigma(true, true, true, true, false, true);
    patterns[kPatternSigmaWithLastLayerNormDistributedMB] =
      DefinePatternEncoderSigma(false, true, true, true, false, true);
    patterns[kPatternSigmaDistributedMB] = DefinePatternEncoderSigma(false, true, true, false, false, true);
    patterns[kPatternSigmaQueryLayerDistributedMB] = DefinePatternEncoderSigma(false, true, true, false, true, true);
    patterns[kPatternSigmaQueryLayerDistributedMBMoe] = DefinePatternEncoderSigma(true, true, true, false, true, true);
    patterns[kPatternSigmaMoeDistributedMB] = DefinePatternEncoderSigma(true, true, true, false, false, true);
    // gelu
    patterns[kPatternSigmaMoeWithLastLayerNormDistributedMBGELU] =
      DefinePatternEncoderSigma(true, true, true, true, false, true, false, true);
    patterns[kPatternSigmaWithLastLayerNormDistributedMBGELU] =
      DefinePatternEncoderSigma(false, true, true, true, false, true, false, true);
    patterns[kPatternSigmaDistributedMBGELU] =
      DefinePatternEncoderSigma(false, true, true, false, false, true, false, true);
    patterns[kPatternSigmaQueryLayerDistributedMBGELU] =
      DefinePatternEncoderSigma(true, true, true, false, true, true, false, true);
    patterns[kPatternSigmaMoeDistributedMBGELU] =
      DefinePatternEncoderSigma(true, true, true, false, false, true, false, true);
    // pangu alpha
    patterns[kPatternDistributedAlpha] = DefinePatternEncoderAlpha(false, true, false, false, true);
    patterns[kPatternDistributedAlphaWithLastLayerNorm] = DefinePatternEncoderAlpha(false, true, true, false, true);
    patterns[kPatternQueryLayerUsePastDistributed] = DefinePatternEncoderAlpha(false, true, false, true, true);
    patterns[kPatternQueryLayerUsePast] = DefinePatternEncoderAlpha(false, false, false, true, true);
    patterns[kPatternEncoderLayerPreNormUsePast] = DefinePatternEncoderAlpha(false, false, false, false, true);
    patterns[kPatternEncoderLayerUsePastWithLastNorm] = DefinePatternEncoderAlpha(false, false, true, false, true);
    patterns[kPatternEncoderLayerNormT5Pre] = DefinePatternEncoderLayer(false, false, true, true, true);
    patterns[kPatternEncoderLayerPre] = DefinePatternEncoderLayer(false);
    patterns[kPatternEncoderLayerPost] = DefinePatternEncoderLayer(true);
    patterns[kPatternEncoderLayerPostNorm] = DefinePatternEncoderLayer(true, true);
    patterns[kPatternEncoderLayerPreNorm] = DefinePatternEncoderLayer(false, true);
    patterns[kPatternEncoderLayerT5Pre] = DefinePatternEncoderLayer(false, false, true, true);
    patterns[kPatternEncoderLayerT5Post] = DefinePatternEncoderLayer(true, false, true, true);
  }
  return patterns;
}

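// The Is* predicates below classify a matched pattern name so Process() can set the
// corresponding fusion flags (use_past, MoE, last layer norm, activation type, ...).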
bool EncoderLayerFusion::IsUsePastAlpha(const std::string &pattern_name) const {
  if (pattern_name == kPatternDistributedAlpha || pattern_name == kPatternDistributedAlphaWithLastLayerNorm ||
      pattern_name == kPatternQueryLayerUsePastDistributed || pattern_name == kPatternQueryLayerUsePast ||
      pattern_name == kPatternEncoderLayerPreNormUsePast || pattern_name == kPatternEncoderLayerUsePastWithLastNorm)
    return true;
  return false;
}

bool EncoderLayerFusion::IsUsePastMB(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigmaWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributedMBGELU ||
      pattern_name == kPatternSigmaQueryLayerDistributedMB || pattern_name == kPatternSigmaQueryLayerDistributedMBMoe ||
      pattern_name == kPatternSigmaQueryLayerDistributedMBGELU ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMBGELU ||
      pattern_name == kPatternSigmaDistributedEmbeddingMB || pattern_name == kPatternSigmaDistributedEmbeddingMBGELU ||
      pattern_name == kPatternSigmaDistributedMB || pattern_name == kPatternSigmaDistributedMBGELU ||
      pattern_name == kPatternSigmaMoeDistributedMB || pattern_name == kPatternSigmaMoeDistributedMBGELU)
    return true;
  return false;
}

bool EncoderLayerFusion::IsUsePast(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigmaDistributed || pattern_name == kPatternSigmaDistributedEmbedding ||
      pattern_name == kPatternSigmaMoeDistributed || pattern_name == kPatternSigmaWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaQueryLayerDistributed || pattern_name == kPatternSigmaQueryLayerDistributedMoe ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributed || pattern_name == kPatternSigma ||
      pattern_name == kPatternSigmaQuery || pattern_name == kPatternSigmaWithLastLayerNorm ||
      pattern_name == kPatternSigmaMoe || pattern_name == kPatternSigmaMoeWithLastLayerNorm ||
      pattern_name == kPatternSigmaQueryLayerMoe ||
      pattern_name == kPatternSigmaEmbedding || pattern_name == kPatternSigmaEmbeddingDistributed)
    return true;
  return false;
}

bool EncoderLayerFusion::IsLastLayerNorm(const std::string &pattern_name) const {
  if (pattern_name == kPatternEncoderLayerNormT5Pre || pattern_name == kPatternEncoderLayerUsePastWithLastNorm ||
      pattern_name == kPatternDistributedAlphaWithLastLayerNorm ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaMoeWithLastLayerNorm || pattern_name == kPatternSigmaWithLastLayerNorm ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributedMBGELU ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMBGELU)
    return true;
  return false;
}

bool EncoderLayerFusion::IsLayerNormFusion(const std::string &pattern_name) const {
  if (pattern_name == kPatternEncoderLayerPostNorm || pattern_name == kPatternEncoderLayerPreNorm ||
      pattern_name == kPatternEncoderLayerPreNormUsePast || pattern_name == kPatternQueryLayerUsePast ||
      pattern_name == kPatternEncoderLayerUsePastWithLastNorm || pattern_name == kPatternDistributedAlpha ||
      pattern_name == kPatternQueryLayerUsePastDistributed || pattern_name == kPatternDistributedAlphaWithLastLayerNorm)
    return true;
  return false;
}

bool EncoderLayerFusion::IsMoe(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigmaMoeDistributed || pattern_name == kPatternSigmaMoeWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaMoe || pattern_name == kPatternSigmaMoeWithLastLayerNorm ||
      pattern_name == kPatternSigmaQueryLayerMoe || pattern_name == kPatternSigmaQueryLayerDistributedMBGELU ||
      pattern_name == kPatternSigmaQueryLayerDistributedMoe ||
      pattern_name == kPatternSigmaQueryLayerDistributedMBMoe ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMBGELU ||
      pattern_name == kPatternSigmaMoeDistributedMB || pattern_name == kPatternSigmaMoeDistributedMBGELU)
    return true;
  return false;
}

bool EncoderLayerFusion::IsFastGeluDistributed(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigmaDistributed || pattern_name == kPatternSigmaDistributedEmbedding ||
      pattern_name == kPatternSigmaMoeDistributed || pattern_name == kPatternSigmaWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaQueryLayerDistributed || pattern_name == kPatternSigmaQueryLayerDistributedMoe ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributed ||
      pattern_name == kPatternSigmaEmbeddingDistributed ||
      pattern_name == kPatternSigmaWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaQueryLayerDistributedMB || pattern_name == kPatternSigmaQueryLayerDistributedMBMoe ||
      pattern_name == kPatternSigmaMoeWithLastLayerNormDistributedMB ||
      pattern_name == kPatternSigmaDistributedEmbeddingMB || pattern_name == kPatternSigmaDistributedMB ||
      pattern_name == kPatternSigmaMoeDistributedMB)
    return true;
  return false;
}

bool EncoderLayerFusion::IsFastGelu(const std::string &pattern_name) const {
  if (pattern_name == kPatternSigma || pattern_name == kPatternSigmaQuery ||
      pattern_name == kPatternSigmaWithLastLayerNorm || pattern_name == kPatternSigmaEmbedding ||
      pattern_name == kPatternSigmaMoe || pattern_name == kPatternSigmaMoeWithLastLayerNorm ||
      pattern_name == kPatternSigmaQueryLayerMoe)
    return true;
  return false;
}

bool EncoderLayerFusion::IsQueryLayer(const std::string &pattern_name) const {
  if (pattern_name == kPatternQueryLayerUsePast || pattern_name == kPatternQueryLayerUsePastDistributed ||
      pattern_name == kPatternSigmaQueryLayerDistributed || pattern_name == kPatternSigmaQueryLayerMoe ||
      pattern_name == kPatternSigmaQueryLayerDistributedMoe || pattern_name == kPatternSigmaQueryLayerDistributedMB ||
      pattern_name == kPatternSigmaQueryLayerDistributedMBMoe ||
      pattern_name == kPatternSigmaQueryLayerDistributedMBGELU || pattern_name == kPatternSigmaQuery)
    return true;
  return false;
}

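// Entry point invoked for every matched pattern: decodes the pattern name into the
// fusion flags and builds the fused encoder-layer node.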
AnfNodePtr EncoderLayerFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
                                       const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    return nullptr;
  }
  bool mask = true;
  is_layernorm_ = IsLastLayerNorm(pattern_name);
  is_layernorm_fusion_ = IsLayerNormFusion(pattern_name);
  is_use_past_ = IsUsePast(pattern_name) || IsUsePastMB(pattern_name) || IsUsePastAlpha(pattern_name);
  is_moe_ = IsMoe(pattern_name);
  is_fast_gelu_ = IsFastGelu(pattern_name) || IsFastGeluDistributed(pattern_name);
  if (pattern_name == kPatternSigmaDistributedEmbeddingMB || pattern_name == kPatternSigmaDistributedEmbedding ||
      pattern_name == kPatternSigmaDistributedEmbeddingMBGELU || pattern_name == kPatternSigmaEmbedding ||
      pattern_name == kPatternSigmaEmbeddingDistributed)
    is_embedding_layer_ = true;
  if (pattern_name == kPatternEncoderLayerT5Pre || pattern_name == kPatternEncoderLayerT5Post ||
      pattern_name == kPatternEncoderLayerNormT5Pre)
    is_position_bias_ = true;
  if (pattern_name == kPatternEncoderLayerPost || pattern_name == kPatternEncoderLayerPostNorm ||
      pattern_name == kPatternEncoderLayerT5Post)
    is_post_layernorm_ = true;
  is_query_layer_ = IsQueryLayer(pattern_name);
  return CreateMaskedEncoderLayerFusionNode(func_graph, equiv, node, is_post_layernorm_, mask);
}

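// Checks that the activation node captured by is_act_ is a GELU. Nodes that
// carry no ActivationType attribute, or a type other than GELU, are rejected
// so that only GELU-based feed-forward blocks reach the fusion.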
bool EncoderLayerFusion::IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                   const VarPtr &input_prim) const {
  auto act_input = GetAttribute(func_graph, equiv, is_act_);
  MS_ASSERT(act_input != nullptr);
  auto act_primitive = ops::GetOperator<ops::Activation>(act_input);
  MS_CHECK_TRUE_RET(act_primitive != nullptr, false);
  auto act_primitive_c = act_primitive->GetPrim();
  if (act_primitive_c->GetAttr(ops::kActivationType) == nullptr ||
      act_primitive->get_activation_type() != mindspore::GELU) {
    return false;
  }
  return true;
}

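// Extracts the layer-norm epsilon bound to node_name in the match. The value
// is expected to live in a ValueNode holding a single-element float tensor;
// anything else is reported as an error.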
STATUS EncoderLayerFusion::GetEps(const EquivPtr &equiv, VarPtr node_name, float *eps) const {
  if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
    MS_LOG(ERROR) << node_name << " is not AnfNodePtr";
    return RET_ERROR;
  }
  AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
  MS_ASSERT(node != nullptr);
  if (utils::isa<ValueNodePtr>(node)) {
    auto value_ptr_node = utils::cast<ValueNodePtr>(node);
    auto value_node = utils::cast<ValuePtr>(value_ptr_node->value());
    if (value_node->isa<tensor::Tensor>()) {
      auto tensor = value_node->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      *eps = *reinterpret_cast<float *>(tensor->data().data());
      return RET_OK;
    }
  }
  return RET_ERROR;
}

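// Resolves the primitive (input 0) of the CNode bound to node_name in the
// match. If the bound node is not itself a CNode, its first user is looked up
// through the graph manager instead, as a fallback for patterns where the
// matched var is an input of the primitive rather than the primitive itself.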
AnfNodePtr EncoderLayerFusion::GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                            VarPtr node_name) const {
  if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
    MS_LOG(ERROR) << node_name << " is not AnfNodePtr";
    return nullptr;
  }
  AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
  MS_ASSERT(node != nullptr);
  if (node == nullptr || !utils::isa<CNodePtr>(node)) {
    auto manager = func_graph->manager();
    if (manager == nullptr) {
      return nullptr;
    }
    auto users = manager->node_users();
    auto it = users.find(node);
    if (it != users.end()) {
      node = it->second.front().first;
    }
    if (node == nullptr || !utils::isa<CNodePtr>(node)) {
      return nullptr;
    }
  }
  auto cnode = utils::cast<CNodePtr>(node);
  MS_ASSERT(cnode != nullptr);
  auto input = cnode->input(0);
  return input;
}

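// Collects the attention and layer-norm attributes (head count, head size,
// epsilons, scale) from the matched subgraph and derives the activation type.
// Non-position-bias patterns that are neither use-past nor query-layer must
// additionally prove their activation is a GELU.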
STATUS EncoderLayerFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num,
                                        int *head_size, float *eps1, float *eps2, float *eps3, float *scale) const {
  auto attn_input = GetAttribute(func_graph, equiv, is_attention_);
  MS_ASSERT(attn_input != nullptr);
  auto attn_prim = ops::GetOperator<ops::Attention>(attn_input);
  MS_CHECK_TRUE_RET(attn_prim != nullptr, RET_ERROR);
  if (attn_prim->GetAttr(ops::kNumHeads) != nullptr) *head_num = attn_prim->get_head_num();
  if (attn_prim->GetAttr(ops::kSizePerHead) != nullptr) *head_size = attn_prim->get_head_size();
  if (attn_prim->GetAttr(ops::kPositionBias1) != nullptr) is_position_bias_ = attn_prim->get_position_bias();
  if (attn_prim->GetAttr(ops::kScale) != nullptr) *scale = attn_prim->get_scale();
  if (is_layernorm_fusion_) {
    auto layrn1_input = GetAttribute(func_graph, equiv, is_layernorm1_);
    auto layrn1_prim = ops::GetOperator<ops::LayerNormFusion>(layrn1_input);
    MS_CHECK_TRUE_RET(layrn1_prim != nullptr, RET_ERROR);
    *eps1 = layrn1_prim->get_epsilon();
    auto layrn2_input = GetAttribute(func_graph, equiv, is_layernorm2_);
    auto layrn2_prim = ops::GetOperator<ops::LayerNormFusion>(layrn2_input);
    MS_CHECK_TRUE_RET(layrn2_prim != nullptr, RET_ERROR);
    *eps2 = layrn2_prim->get_epsilon();
    if (is_layernorm_) {
      auto layrn3_input = GetAttribute(func_graph, equiv, is_layernorm3_);
      auto layrn3_prim = ops::GetOperator<ops::LayerNormFusion>(layrn3_input);
      MS_CHECK_TRUE_RET(layrn3_prim != nullptr, RET_ERROR);
      *eps3 = layrn3_prim->get_epsilon();
    }
  } else {
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps1_, eps1) == RET_OK, RET_ERROR, "not found eps1");
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps2_, eps2) == RET_OK, RET_ERROR, "not found eps2");
    if (is_layernorm_) {
      MS_CHECK_TRUE_MSG(GetEps(equiv, eps3_, eps3) == RET_OK, RET_ERROR, "not found eps3");
    }
  }
  act_type_ = (is_position_bias_) ? (ActType::ActType_Relu)
              : (is_fast_gelu_)   ? (ActType::ActType_FastGelu)
                                  : (ActType::ActType_Gelu);
  if (!is_position_bias_ && !is_use_past_ && !is_query_layer_) {
    if (!IsActGELU(func_graph, equiv, is_act_)) {
      return RET_ERROR;
    }
  }
  return RET_OK;
}

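// Builds the ops::EncoderLayer primitive from the attributes gathered by
// CheckPattern. Epsilons default to 1e-5 and scale to 1.0 when the matched
// attention node does not carry the corresponding attributes.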
std::shared_ptr<ops::EncoderLayer> EncoderLayerFusion::CreatePrim(const FuncGraphPtr &func_graph,
                                                                  const EquivPtr &equiv, int64_t ffn_hidden_size,
                                                                  int64_t expert_num, int64_t expert_offset,
                                                                  float capacity_factor) const {
  auto encoder_layer_prim = std::make_shared<ops::EncoderLayer>();
  if (encoder_layer_prim == nullptr) {
    MS_LOG(ERROR) << "Build encoder layer primitive failed.";
    return nullptr;
  }
  int head_num = 0;
  int head_size = 0;
  float eps1 = 1e-5f;
  float eps2 = 1e-5f;
  float eps3 = 1e-5f;
  float scale = 1.0f;
  if (CheckPattern(func_graph, equiv, &head_num, &head_size, &eps1, &eps2, &eps3, &scale) != RET_OK) {
    return nullptr;
  }
  encoder_layer_prim->Init(head_num, head_size, eps1, eps2, eps3, ffn_hidden_size, expert_num, expert_offset,
                           capacity_factor, is_position_bias_, is_post_layernorm_, scale, act_type_, is_layernorm_,
                           is_use_past_, is_query_layer_, is_moe_, is_embedding_layer_);
  return encoder_layer_prim;
}

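// Derives shape-based attributes from the matched weights. ffn_hidden_size
// comes from the FFN mapping weight: shape [expert_num, ..., ffn_hidden] for
// MoE (the middle dimension is assumed to be the model hidden size),
// [hidden, ffn_hidden] otherwise. For MoE the capacity factor is recovered
// from the stored fp16 expert capacity as
//   capacity_factor = expert_capacity * expert_num / seq_length,
// e.g. (illustrative numbers only) expert_capacity = 128, expert_num = 8 and
// seq_length = 1024 give capacity_factor = 1.0.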
STATUS EncoderLayerFusion::InitAttributes(AnfNodePtr k_past, AnfNodePtr begin_expert_ids, AnfNodePtr weight_m,
                                          AnfNodePtr expert_capacity_node, int *ffn_hidden_size, int *expert_num,
                                          int *expert_offset, float *capacity_factor) const {
  auto base_shape_ptr = weight_m->Shape();
  MS_CHECK_TRUE_RET(base_shape_ptr != nullptr, RET_ERROR);
  auto input_shape_ptr = base_shape_ptr->cast<abstract::ShapePtr>();
  MS_CHECK_TRUE_RET(input_shape_ptr != nullptr, RET_ERROR);
  auto input_shape = input_shape_ptr->shape();
  MS_CHECK_TRUE_RET(input_shape.size() >= C2NUM, RET_ERROR);
  if (is_moe_) {
    auto begin_expert_ids_node = begin_expert_ids->cast<ValueNodePtr>();
    MS_CHECK_TRUE_RET(begin_expert_ids_node != nullptr, RET_ERROR);
    *expert_num = static_cast<int>(input_shape[0]);
    *expert_offset = CastToInt(begin_expert_ids_node->value())[0];
    auto base_shape_k = k_past->Shape();
    MS_CHECK_TRUE_RET(base_shape_k != nullptr, RET_ERROR);
    auto k_shape_ptr = base_shape_k->cast<abstract::ShapePtr>();
    MS_CHECK_TRUE_RET(k_shape_ptr != nullptr, RET_ERROR);
    auto k_shape = k_shape_ptr->shape();
    int seq = static_cast<int>(k_shape[C3NUM]);
    auto expert_capacity_value_node = utils::cast<ValuePtr>(utils::cast<ValueNodePtr>(expert_capacity_node)->value());
    if (expert_capacity_value_node->isa<tensor::Tensor>()) {
      auto tensor = expert_capacity_value_node->cast<tensor::TensorPtr>();
      auto expert_capacity = *(reinterpret_cast<float16 *>(tensor->data().data()));
      float cast_expert_capacity = Float16::ToFloat32(expert_capacity);
      *capacity_factor = cast_expert_capacity * (*expert_num) / seq;
    }
    *ffn_hidden_size = static_cast<int>(input_shape[C2NUM]);
  } else {
    *ffn_hidden_size = static_cast<int>(input_shape[1]);
  }
  return RET_OK;
}

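// Assembles the fused EncoderLayer CNode. The input order is significant:
// the hidden input comes first, followed by the per-variant weights, biases
// and norm parameters, with the attention mask, expert ids, past key/value
// caches and embedding tables spliced in only when the corresponding feature
// flag is set.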
CNodePtr EncoderLayerFusion::CreateMaskedEncoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                                const AnfNodePtr &node, bool post_layernorm,
                                                                bool mask = true) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(node != nullptr);
  AnfNodePtr v_past, k_past, batch_valid_length, embedding_table, expert_ids, begin_expert_ids, expert_capacity_node,
    position_ids, embedding_table_input, input_ids, current_index, embedding_table_pos;
  auto input = utils::cast<AnfNodePtr>((*equiv)[input_]);
  if (is_moe_) {
    expert_ids = utils::cast<AnfNodePtr>((*equiv)[expert_ids_]);
    begin_expert_ids = utils::cast<AnfNodePtr>((*equiv)[begin_expert_ids_]);
    expert_capacity_node = utils::cast<AnfNodePtr>((*equiv)[expert_capacity_]);
  }
  if (is_use_past_) {
    k_past = utils::cast<AnfNodePtr>((*equiv)[k_past_]);
    v_past = utils::cast<AnfNodePtr>((*equiv)[v_past_]);
  }
  auto weight_qkv = utils::cast<AnfNodePtr>((*equiv)[weight_attn_qkv_]);
  auto weight_attn_o = utils::cast<AnfNodePtr>((*equiv)[weight_attn_o_]);
  auto weight_m = utils::cast<AnfNodePtr>((*equiv)[weight_m_]);
  auto weight_p = utils::cast<AnfNodePtr>((*equiv)[weight_p_]);
  auto gamma1 = utils::cast<AnfNodePtr>((*equiv)[gamma1_]);
  auto gamma2 = utils::cast<AnfNodePtr>((*equiv)[gamma2_]);
  AnfNodePtr input_mask = mask ? utils::cast<AnfNodePtr>((*equiv)[mask_]) : nullptr;
  int ffn_hidden_size = 0, expert_num = 1, expert_offset = 0;
  float capacity_factor = 0;
  if (InitAttributes(k_past, begin_expert_ids, weight_m, expert_capacity_node, &ffn_hidden_size, &expert_num,
                     &expert_offset, &capacity_factor) != RET_OK) {
    MS_LOG(ERROR) << "Init Attributes failed.";
    return nullptr;
  }
  auto encoder_layer_prim = CreatePrim(func_graph, equiv, ffn_hidden_size, expert_num, expert_offset, capacity_factor);
  MS_CHECK_TRUE_RET(encoder_layer_prim != nullptr, nullptr);
  auto encoder_layer_prim_c = encoder_layer_prim->GetPrim();
  auto value_node = NewValueNode(encoder_layer_prim_c);
  std::vector<AnfNodePtr> new_node_inputs = {value_node, input};
  if (is_position_bias_) {
    auto position_bias = utils::cast<AnfNodePtr>((*equiv)[position_bias_]);
    new_node_inputs.insert(new_node_inputs.end(), {gamma1, weight_qkv});
    if (mask) new_node_inputs.push_back(input_mask);
    new_node_inputs.insert(new_node_inputs.end(), {position_bias, weight_attn_o, gamma2, weight_m, weight_p});
    if (is_layernorm_) {
      auto gamma3 = utils::cast<AnfNodePtr>((*equiv)[gamma3_]);
      new_node_inputs.push_back(gamma3);
    }
  } else {
    auto bias_attn_qkv = utils::cast<AnfNodePtr>((*equiv)[bias_attn_qkv_]);
    auto bias_attn_o = utils::cast<AnfNodePtr>((*equiv)[bias_attn_o_]);
    auto bias_m = utils::cast<AnfNodePtr>((*equiv)[bias_m_]);
    auto bias_p = utils::cast<AnfNodePtr>((*equiv)[bias_p_]);
    auto beta1 = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
    auto beta2 = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
    if (!is_post_layernorm_) {
      if (is_use_past_) new_node_inputs.insert(new_node_inputs.end(), {k_past, v_past});
      if (is_query_layer_) {
        auto input_q = utils::cast<AnfNodePtr>((*equiv)[input_q_]);
        auto weight_q = utils::cast<AnfNodePtr>((*equiv)[weight_attn_q_]);
        new_node_inputs.insert(new_node_inputs.end(), {gamma1, beta1, input_q, weight_q, weight_qkv, bias_attn_qkv});
      } else {
        new_node_inputs.insert(new_node_inputs.end(), {gamma1, beta1, weight_qkv, bias_attn_qkv});
      }
      if (mask) new_node_inputs.push_back(input_mask);
      new_node_inputs.insert(new_node_inputs.end(), {weight_attn_o, bias_attn_o, gamma2, beta2});
      if (is_moe_) new_node_inputs.push_back(expert_ids);
      new_node_inputs.insert(new_node_inputs.end(), {weight_m, bias_m, weight_p, bias_p});
    } else {
      new_node_inputs.insert(new_node_inputs.end(), {weight_qkv, bias_attn_qkv});
      if (mask) new_node_inputs.push_back(input_mask);
      new_node_inputs.insert(new_node_inputs.end(), {weight_attn_o, bias_attn_o, gamma1, beta1, weight_m, bias_m,
                                                     weight_p, bias_p, gamma2, beta2});
    }
    if (is_layernorm_) {
      auto beta3 = utils::cast<AnfNodePtr>((*equiv)[beta3_]);
      auto gamma3 = utils::cast<AnfNodePtr>((*equiv)[gamma3_]);
      new_node_inputs.insert(new_node_inputs.end(), {gamma3, beta3});
    }
  }
  auto inputs = func_graph->get_inputs();
  MS_CHECK_TRUE_RET(inputs.size() > C2NUM, nullptr);
  if (is_query_layer_) {
    embedding_table_pos = utils::cast<AnfNodePtr>((*equiv)[embedding_table_pos_]);
    embedding_table = utils::cast<AnfNodePtr>((*equiv)[embedding_table_]);
    new_node_inputs.insert(new_node_inputs.end(), {embedding_table, embedding_table_pos, inputs.end()[-2],
                                                   inputs.end()[-3], inputs.end()[-1]});
  } else if (is_use_past_) {  // temporary solution
    if (is_embedding_layer_) {
      embedding_table_input = utils::cast<AnfNodePtr>((*equiv)[embedding_table_input_]);
      embedding_table_pos = utils::cast<AnfNodePtr>((*equiv)[embedding_table_pos_]);
      new_node_inputs.insert(new_node_inputs.end(), {embedding_table_input, embedding_table_pos});
    }
    new_node_inputs.insert(new_node_inputs.end(), {inputs.end()[-3], inputs.end()[-1]});
  }
  auto new_node = func_graph->NewCNode(new_node_inputs);
  MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
  auto old_node = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(old_node != nullptr, nullptr);
  MS_CHECK_TRUE_RET(old_node->abstract() != nullptr, nullptr);
  new_node->set_abstract(old_node->abstract()->Clone());
  new_node->set_fullname_with_scope(node->fullname_with_scope() + "/encoder_layer");
  return new_node;
}
}  // namespace mindspore::opt