/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/decoder_layer_fusion.h"
#include <functional>
#include <utility>
#include <vector>
#include <algorithm>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "nnacl/op_base.h"
#include "ops/tuple_get_item.h"
#include "tools/common/tensor_util.h"
#include "ops/op_utils.h"

namespace mindspore::opt {
namespace {
const auto &p1 = std::placeholders::_1;
} // namespace
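// DecoderLayerFusion scans the graph for a complete transformer decoder layer
// (self-attention -> add & norm -> cross-attention -> add & norm -> FFN -> add)
// and replaces the matched subgraph with a single ops::DecoderLayer node.
//
// Init() allocates the pattern variables that the VectorRef patterns below bind
// to concrete graph nodes: layer-norm gammas/betas, attention and FFN weights
// and biases, masks, position biases, and epsilon placeholders.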
bool DecoderLayerFusion::Init() const {
  hidden_stats_ = std::make_shared<Var>("input");
  MS_CHECK_TRUE_RET(hidden_stats_ != nullptr, false);
  encoder_output_ = std::make_shared<Var>("input");
  MS_CHECK_TRUE_RET(encoder_output_ != nullptr, false);
  beta1_ = std::make_shared<Var>("beta1");
  MS_CHECK_TRUE_RET(beta1_ != nullptr, false);
  gamma1_ = std::make_shared<Var>("gamma1");
  MS_CHECK_TRUE_RET(gamma1_ != nullptr, false);
  beta2_ = std::make_shared<Var>("beta2");
  MS_CHECK_TRUE_RET(beta2_ != nullptr, false);
  gamma2_ = std::make_shared<Var>("gamma2");
  MS_CHECK_TRUE_RET(gamma2_ != nullptr, false);
  beta3_ = std::make_shared<Var>("beta3");
  MS_CHECK_TRUE_RET(beta3_ != nullptr, false);
  gamma3_ = std::make_shared<Var>("gamma3");
  MS_CHECK_TRUE_RET(gamma3_ != nullptr, false);
  gamma4_ = std::make_shared<Var>("gamma4");
  MS_CHECK_TRUE_RET(gamma4_ != nullptr, false);
  beta4_ = std::make_shared<Var>("beta4");
  MS_CHECK_TRUE_RET(beta4_ != nullptr, false);
  weight_attn_qkv_ = std::make_shared<Var>("weight_attn_qkv");
  MS_CHECK_TRUE_RET(weight_attn_qkv_ != nullptr, false);
  weight_attn_q_ = std::make_shared<Var>("weight_attn_q_");
  MS_CHECK_TRUE_RET(weight_attn_q_ != nullptr, false);
  weight_attn_kv_ = std::make_shared<Var>("weight_attn_kv_");
  MS_CHECK_TRUE_RET(weight_attn_kv_ != nullptr, false);
  weight_attn_o_ = std::make_shared<CondVar>(IsParamNode, "weight_attn_o");
  MS_CHECK_TRUE_RET(weight_attn_o_ != nullptr, false);
  weight_attn_cross_o_ = std::make_shared<CondVar>(IsParamNode, "weight_attn_cross_o_");
  MS_CHECK_TRUE_RET(weight_attn_cross_o_ != nullptr, false);
  weight_m_ = std::make_shared<CondVar>(IsParamNode, "weight_m");
  MS_CHECK_TRUE_RET(weight_m_ != nullptr, false);
  weight_p_ = std::make_shared<CondVar>(IsParamNode, "weight_p");
  MS_CHECK_TRUE_RET(weight_p_ != nullptr, false);
  bias_attn_qkv_ = std::make_shared<Var>("bias_attn_qkv");
  MS_CHECK_TRUE_RET(bias_attn_qkv_ != nullptr, false);
  bias_attn_o_ = std::make_shared<CondVar>(IsParamNode, "bias_attn_o");
  MS_CHECK_TRUE_RET(bias_attn_o_ != nullptr, false);
  bias_attn_cross_qkv_ = std::make_shared<Var>("bias_attn_cross_qkv_");
  MS_CHECK_TRUE_RET(bias_attn_cross_qkv_ != nullptr, false);
  bias_attn_cross_o_ = std::make_shared<CondVar>(IsParamNode, "bias_attn_cross_o_");
  MS_CHECK_TRUE_RET(bias_attn_cross_o_ != nullptr, false);
  bias_m_ = std::make_shared<CondVar>(IsParamNode, "bias_m");
  MS_CHECK_TRUE_RET(bias_m_ != nullptr, false);
  bias_p_ = std::make_shared<CondVar>(IsParamNode, "bias_p");
  MS_CHECK_TRUE_RET(bias_p_ != nullptr, false);
  mask_ = std::make_shared<Var>("mask");
  MS_CHECK_TRUE_RET(mask_ != nullptr, false);
  cross_mask_ = std::make_shared<Var>("cross_mask_");
  MS_CHECK_TRUE_RET(cross_mask_ != nullptr, false);
  is_attention_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAttention), "is_attention");
  MS_CHECK_TRUE_RET(is_attention_ != nullptr, false);
  is_attention_cross_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAttention), "is_attention_cross");
  MS_CHECK_TRUE_RET(is_attention_cross_ != nullptr, false);
  is_layernorm1_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm1");
  MS_CHECK_TRUE_RET(is_layernorm1_ != nullptr, false);
  is_layernorm2_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm2");
  MS_CHECK_TRUE_RET(is_layernorm2_ != nullptr, false);
  is_layernorm3_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimLayerNormFusion), "layer_norm3");
  MS_CHECK_TRUE_RET(is_layernorm3_ != nullptr, false);
  position_bias_ = std::make_shared<Var>("position_bias");
  MS_CHECK_TRUE_RET(position_bias_ != nullptr, false);
  position_bias_cross_ = std::make_shared<Var>("position_bias_cross_");
  MS_CHECK_TRUE_RET(position_bias_cross_ != nullptr, false);
  is_act_ = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimActivation), "activation");
  MS_CHECK_TRUE_RET(is_act_ != nullptr, false);
  eps1_ = std::make_shared<Var>("eps1_");
  MS_CHECK_TRUE_RET(eps1_ != nullptr, false);
  eps2_ = std::make_shared<Var>("eps2_");
  MS_CHECK_TRUE_RET(eps2_ != nullptr, false);
  eps3_ = std::make_shared<Var>("eps3_");
  MS_CHECK_TRUE_RET(eps3_ != nullptr, false);
  eps4_ = std::make_shared<Var>("eps4_");
  MS_CHECK_TRUE_RET(eps4_ != nullptr, false);
  return true;
}
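// Matches the first reshape of the decoder input followed by layer norm 1.
// With layernorm_fusion the norm is a fused LayerNormFusion node consumed
// through TupleGetItem; otherwise the norm is matched as the decomposed
// subgraph built by DefineLayerNorm.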
VectorRef DecoderLayerFusion::getTuple(bool post_layernorm, bool layernorm_fusion = false,
                                       bool is_position_bias = false) const {
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1-reshape");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, hidden_stats_, var1});
  VectorRef layer_norm, tuple;
  if (!layernorm_fusion) {
    return DefineLayerNorm(reshape1, gamma1_, beta1_, eps1_);
  }
  layer_norm = VectorRef({is_layernorm1_, reshape1, gamma1_, beta1_});
  auto is_tuple = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item");
  MS_CHECK_TRUE_RET(is_tuple != nullptr, {});
  auto var_tuple = std::make_shared<Var>("var_tuple");
  MS_CHECK_TRUE_RET(var_tuple != nullptr, {});
  tuple = VectorRef({is_tuple, layer_norm, var_tuple});
  return tuple;
}
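// Matches a decomposed (un-fused) layer norm: Square -> ReduceFusion ->
// AddFusion(eps) -> Sqrt -> RealDiv -> MulFusion(gamma). Note that the
// pattern never consumes beta and does not subtract the mean, which
// corresponds to a T5-style RMSNorm rather than a full LayerNorm.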
VectorRef DecoderLayerFusion::DefineLayerNorm(VectorRef input, VarPtr gamma, VarPtr beta, VarPtr eps) const {
  auto is_sqr = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSquare), "sqr2");
  MS_CHECK_TRUE_RET(is_sqr != nullptr, {});
  auto sqr = VectorRef({is_sqr, input});
  auto var1 = std::make_shared<Var>("var1");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto is_reduce = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReduceFusion), "reduce");
  MS_CHECK_TRUE_RET(is_reduce != nullptr, {});
  auto reduce = VectorRef({is_reduce, sqr, var1});
  auto is_add = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is-add");
  MS_CHECK_TRUE_RET(is_add != nullptr, {});
  auto add = VectorRef({is_add, reduce, eps});
  auto is_sqrt = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimSqrt), "sqrt");
  MS_CHECK_TRUE_RET(is_sqrt != nullptr, {});
  auto sqrt = VectorRef({is_sqrt, add});
  auto is_div = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimRealDiv), "real-div");
  MS_CHECK_TRUE_RET(is_div != nullptr, {});
  auto real_div = VectorRef({is_div, input, sqrt});
  auto is_mul = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMulFusion), "mul");
  MS_CHECK_TRUE_RET(is_mul != nullptr, {});
  auto mul = VectorRef({is_mul, real_div, gamma});
  return mul;
}
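// Builds the full decoder-layer pattern. The flags select the variant:
// post_layernorm (norm applied after the residual add), layernorm_fusion
// (fused LayerNormFusion vs. decomposed norm), is_position_bias (T5-style
// relative position bias, no biases or GELU), mask (attention masks present),
// and is_layer_norm (a fourth, final norm after the FFN residual).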
VectorRef DecoderLayerFusion::DefinePatternDecoderLayer(bool post_layernorm = true, bool layernorm_fusion = false,
                                                        bool is_position_bias = false, bool mask = true,
                                                        bool is_layer_norm = false) const {
  auto is_reshape1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder");
  MS_CHECK_TRUE_RET(is_reshape1 != nullptr, {});
  auto var1 = std::make_shared<Var>("var1-reshape");
  MS_CHECK_TRUE_RET(var1 != nullptr, {});
  auto reshape1 = VectorRef({is_reshape1, hidden_stats_, var1});
  VectorRef inputs, input_cross, tuple2, tuple3, matmul2, tuple4, tuple5;
  if (is_position_bias) {
    inputs = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_, weight_attn_o_,
                        position_bias_});
  } else {
    inputs = VectorRef({is_attention_, getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias),
                        getTuple(post_layernorm, layernorm_fusion, is_position_bias), weight_attn_qkv_, weight_attn_o_,
                        bias_attn_qkv_, bias_attn_o_});
  }
  if (mask) inputs.push_back(mask_);
  auto attention = VectorRef(inputs);
  if (is_position_bias) {
    tuple4 = attention;
  } else {
    auto is_tuple4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item4");
    auto var_tuple4 = std::make_shared<Var>("var_tuple4");
    tuple4 = VectorRef({is_tuple4, attention, var_tuple4});
  }
  auto is_add2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add2");
  MS_CHECK_TRUE_RET(is_add2 != nullptr, {});
  auto add2 = (post_layernorm)
                ? VectorRef({is_add2, getTuple(post_layernorm, layernorm_fusion, is_position_bias), tuple4})
                : VectorRef({is_add2, reshape1, tuple4});
  if (layernorm_fusion) {
    auto layer_norm2 = VectorRef({is_layernorm2_, add2, gamma2_, beta2_});
    auto is_tuple2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item2");
    auto var_tuple2 = std::make_shared<Var>("var_tuple2");
    tuple2 = VectorRef({is_tuple2, layer_norm2, var_tuple2});
  } else {
    tuple2 = DefineLayerNorm(add2, gamma2_, beta2_, eps2_);
  }
  auto is_reshape2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder2");
  MS_CHECK_TRUE_RET(is_reshape2 != nullptr, {});
  auto var2 = std::make_shared<Var>("var2");
  MS_CHECK_TRUE_RET(var2 != nullptr, {});
  auto reshape2 = VectorRef({is_reshape2, encoder_output_, var2});
  if (is_position_bias) {
    input_cross = VectorRef({is_attention_cross_, tuple2, reshape2, reshape2, weight_attn_q_, weight_attn_kv_,
                             weight_attn_cross_o_, position_bias_cross_});
  } else {
    input_cross = VectorRef({is_attention_cross_, tuple2, reshape2, reshape2, weight_attn_q_, weight_attn_kv_,
                             weight_attn_cross_o_, bias_attn_cross_qkv_, bias_attn_cross_o_});
  }
  if (mask) input_cross.push_back(cross_mask_);
  auto attention_cross = VectorRef(input_cross);
  if (is_position_bias) {
    tuple5 = attention_cross;
  } else {
    auto is_tuple5 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item5");
    auto var_tuple5 = std::make_shared<Var>("var_tuple5");
    tuple5 = VectorRef({is_tuple5, attention_cross, var_tuple5});
  }
  auto is_add3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add3");
  MS_CHECK_TRUE_RET(is_add3 != nullptr, {});
  auto add3 = (post_layernorm) ? VectorRef({is_add3, tuple2, tuple5}) : VectorRef({is_add3, add2, tuple5});
  if (layernorm_fusion) {
    auto layer_norm3 = VectorRef({is_layernorm3_, add3, gamma3_, beta3_});
    auto is_tuple3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimTupleGetItem), "tuple_get_item3");
    auto var_tuple3 = std::make_shared<Var>("var_tuple3");
    tuple3 = VectorRef({is_tuple3, layer_norm3, var_tuple3});
  } else {
    tuple3 = DefineLayerNorm(add3, gamma3_, beta3_, eps3_);
  }
  auto is_matmul1 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul1");
  MS_CHECK_TRUE_RET(is_matmul1 != nullptr, {});
  auto is_matmul2 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimMatMulFusion), "is_matmul2");
  MS_CHECK_TRUE_RET(is_matmul2 != nullptr, {});
  if (!is_position_bias) {
    auto matmul1 = VectorRef({is_matmul1, tuple3, weight_m_, bias_m_});
    auto act = VectorRef({is_act_, matmul1});
    matmul2 = VectorRef({is_matmul2, act, weight_p_, bias_p_});
  } else {
    auto matmul1 = VectorRef({is_matmul1, tuple3, weight_m_});
    matmul2 = VectorRef({is_matmul2, matmul1, weight_p_});
  }
  auto is_reshape3 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder3");
  MS_CHECK_TRUE_RET(is_reshape3 != nullptr, {});
  auto var3 = std::make_shared<Var>("var3");
  MS_CHECK_TRUE_RET(var3 != nullptr, {});
  auto reshape3 = VectorRef({is_reshape3, matmul2, var3});
  auto is_reshape4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimReshape), "reshape-decoder4");
  MS_CHECK_TRUE_RET(is_reshape4 != nullptr, {});
  auto var4 = std::make_shared<Var>("var4");
  MS_CHECK_TRUE_RET(var4 != nullptr, {});
  auto reshape4 = (post_layernorm) ? VectorRef({is_reshape4, tuple3, var4}) : VectorRef({is_reshape4, add3, var4});
  auto is_add4 = std::make_shared<CondVar>(std::bind(IsOpType, p1, prim::kPrimAddFusion), "is_add4");
  auto add4 = VectorRef({is_add4, reshape4, reshape3});
  return (is_layer_norm) ? DefineLayerNorm(add4, gamma4_, beta4_, eps4_) : add4;
}
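// Registers one pattern per supported decoder variant. Argument order is
// (post_layernorm, layernorm_fusion, is_position_bias, mask, is_layer_norm).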
std::unordered_map<std::string, VectorRef> DecoderLayerFusion::DefinePatterns() const {
  std::unordered_map<std::string, VectorRef> patterns;
  if (!Init()) {
    MS_LOG(ERROR) << "initializing pattern variables failed.";
    return patterns;
  }
  patterns[kPatternDecoderLayerNormT5Pre] = DefinePatternDecoderLayer(false, false, true, true, true);
  patterns[kPatternDecoderLayerPre] = DefinePatternDecoderLayer(false, true, false, true);
  patterns[kPatternDecoderLayerPost] = DefinePatternDecoderLayer(true, true, false, true);
  patterns[kPatternDecoderLayerNormPre] = DefinePatternDecoderLayer(false, false, false, true);
  patterns[kPatternDecoderLayerNormPost] = DefinePatternDecoderLayer(true, false, false, true);
  patterns[kPatternDecoderT5Pre] = DefinePatternDecoderLayer(false, false, true, true);
  patterns[kPatternDecoderT5Post] = DefinePatternDecoderLayer(true, false, true, true);
  return patterns;
}
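// Entry point invoked when one of the registered patterns matched. Decodes
// the variant from the pattern name into the mutable member flags, then
// emits the fused node.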
AnfNodePtr DecoderLayerFusion::Process(const std::string &pattern_name, const mindspore::FuncGraphPtr &func_graph,
                                       const mindspore::AnfNodePtr &node, const mindspore::EquivPtr &equiv) const {
  if (func_graph == nullptr || node == nullptr || equiv == nullptr) {
    return nullptr;
  }
  // Reset the per-match flags so state from a previously processed pattern cannot leak into this match.
  is_position_bias_ = false;
  is_layernorm_ = false;
  is_layernorm_fusion_ = false;
  if (pattern_name == kPatternDecoderT5Pre || pattern_name == kPatternDecoderT5Post ||
      pattern_name == kPatternDecoderLayerNormT5Pre) {
    is_position_bias_ = true;
  }
  if (pattern_name == kPatternDecoderLayerNormT5Pre) {
    is_layernorm_ = true;
  }
  if (pattern_name == kPatternDecoderLayerPre || pattern_name == kPatternDecoderLayerPost) {
    is_layernorm_fusion_ = true;
  }
  bool mask = true;
  bool post_layernorm = false;
  if (pattern_name == kPatternDecoderLayerPost || pattern_name == kPatternDecoderT5Post ||
      pattern_name == kPatternDecoderLayerNormPost) {
    post_layernorm = true;
  }
  return CreateMaskedDecoderLayerFusionNode(func_graph, equiv, node, post_layernorm, mask);
}
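// Verifies that the matched FFN activation is GELU; other activation types
// are rejected so the fusion is not applied to unsupported layers.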
bool DecoderLayerFusion::IsActGELU(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const {
  auto act_input = GetAttribute(func_graph, equiv, is_act_);
  MS_ASSERT(act_input != nullptr);
  auto act_primitive = ops::GetOperator<ops::Activation>(act_input);
  MS_CHECK_TRUE_RET(act_primitive != nullptr, false);
  auto act_primitive_c = act_primitive->GetPrim();
  if (act_primitive_c->GetAttr(ops::kActivationType) == nullptr ||
      act_primitive->get_activation_type() != mindspore::GELU) {
    return false;
  }
  return true;
}
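// Resolves the primitive (input 0) of the CNode bound to node_name in the
// match. If the bound node is not itself a CNode, its first user is looked
// up through the graph manager instead.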
AnfNodePtr DecoderLayerFusion::GetAttribute(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                            VarPtr node_name) const {
  if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
    MS_LOG(ERROR) << node_name << " is not an AnfNodePtr";
    return nullptr;
  }
  AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
  MS_ASSERT(node != nullptr);
  if (node == nullptr || !utils::isa<CNodePtr>(node)) {
    auto manager = func_graph->manager();
    if (manager == nullptr) {
      return nullptr;
    }
    auto users = manager->node_users();
    auto it = users.find(node);
    if (it != users.end()) {
      node = it->second.front().first;
    }
    if (node == nullptr || !utils::isa<CNodePtr>(node)) {
      return nullptr;
    }
  }
  auto cnode = utils::cast<CNodePtr>(node);
  MS_ASSERT(cnode != nullptr);
  auto input = cnode->input(0);
  return input;
}
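// Extracts the layer-norm epsilon from the matched constant: the bound node
// must be a ValueNode holding a float tensor whose first element is eps.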
STATUS DecoderLayerFusion::GetEps(const EquivPtr &equiv, VarPtr node_name, float *eps) const {
  if ((*equiv)[node_name] == nullptr || !utils::isa<AnfNodePtr>((*equiv)[node_name])) {
    MS_LOG(ERROR) << node_name << " is not an AnfNodePtr";
    return RET_ERROR;
  }
  AnfNodePtr node = utils::cast<AnfNodePtr>((*equiv)[node_name]);
  MS_ASSERT(node != nullptr);
  if (utils::isa<ValueNodePtr>(node)) {
    auto value_ptr_node = utils::cast<ValueNodePtr>(node);
    auto value_node = utils::cast<ValuePtr>(value_ptr_node->value());
    if (value_node->isa<tensor::Tensor>()) {
      auto tensor = value_node->cast<tensor::TensorPtr>();
      MS_EXCEPTION_IF_NULL(tensor);
      *eps = *reinterpret_cast<float *>(tensor->data().data());
      return RET_OK;
    }
  }
  return RET_ERROR;
}
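// Pulls the attention head count/size, position-bias flags, scales, and
// epsilons out of the matched attention and layer-norm primitives, and fixes
// the FFN activation: GELU for the standard variant, ReLU when position bias
// is used, matching T5's ReLU FFN.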
STATUS DecoderLayerFusion::CheckPattern(const FuncGraphPtr &func_graph, const EquivPtr &equiv, int *head_num,
                                        int *head_size, float *eps1, float *eps2, float *eps3, float *eps4,
                                        bool *is_position_bias1, bool *is_position_bias2, float *scale1,
                                        float *scale2) const {
  auto attn_input = GetAttribute(func_graph, equiv, is_attention_);
  MS_ASSERT(attn_input != nullptr);
  auto attn_prim = ops::GetOperator<ops::Attention>(attn_input);
  MS_CHECK_TRUE_RET(attn_prim != nullptr, RET_ERROR);
  if (attn_prim->GetAttr(ops::kNumHeads) != nullptr) *head_num = attn_prim->get_head_num();
  if (attn_prim->GetAttr(ops::kSizePerHead) != nullptr) *head_size = attn_prim->get_head_size();
  if (attn_prim->GetAttr(ops::kPositionBias1) != nullptr) *is_position_bias1 = attn_prim->get_position_bias();
  if (attn_prim->GetAttr(ops::kScale) != nullptr) *scale1 = attn_prim->get_scale();
  auto attn_cross_input = GetAttribute(func_graph, equiv, is_attention_cross_);
  MS_ASSERT(attn_cross_input != nullptr);
  auto attn_cross_prim = ops::GetOperator<ops::Attention>(attn_cross_input);
  MS_CHECK_TRUE_RET(attn_cross_prim != nullptr, RET_ERROR);
  if (attn_cross_prim->GetAttr(ops::kPositionBias1) != nullptr) {
    *is_position_bias2 = attn_cross_prim->get_position_bias();
  }
  if (attn_cross_prim->GetAttr(ops::kScale) != nullptr) *scale2 = attn_cross_prim->get_scale();
  if (is_layernorm_fusion_) {
    auto layrn1_input = GetAttribute(func_graph, equiv, is_layernorm1_);
    auto layrn1_prim = ops::GetOperator<ops::LayerNormFusion>(layrn1_input);
    if (layrn1_prim->GetAttr(ops::kEpsilon) != nullptr) *eps1 = layrn1_prim->get_epsilon();
    auto layrn2_input = GetAttribute(func_graph, equiv, is_layernorm2_);
    auto layrn2_prim = ops::GetOperator<ops::LayerNormFusion>(layrn2_input);
    if (layrn2_prim->GetAttr(ops::kEpsilon) != nullptr) *eps2 = layrn2_prim->get_epsilon();
    auto layrn3_input = GetAttribute(func_graph, equiv, is_layernorm3_);
    auto layrn3_prim = ops::GetOperator<ops::LayerNormFusion>(layrn3_input);
    if (layrn3_prim->GetAttr(ops::kEpsilon) != nullptr) *eps3 = layrn3_prim->get_epsilon();
  } else {
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps1_, eps1) == RET_OK, RET_ERROR, "not found eps1");
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps2_, eps2) == RET_OK, RET_ERROR, "not found eps2");
    MS_CHECK_TRUE_MSG(GetEps(equiv, eps3_, eps3) == RET_OK, RET_ERROR, "not found eps3");
    if (is_layernorm_) {
      MS_CHECK_TRUE_MSG(GetEps(equiv, eps4_, eps4) == RET_OK, RET_ERROR, "not found eps4");
    }
  }
  if (!is_position_bias_) {
    if (!IsActGELU(func_graph, equiv)) return RET_ERROR;
    act_type_ = ActType::ActType_Gelu;
  } else {
    act_type_ = ActType::ActType_Relu;
  }
  return RET_OK;
}
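// Builds the ops::DecoderLayer primitive from the attributes recovered by
// CheckPattern. The defaults below (eps = 1e-6, scale = 1.0f) are only used
// when the matched primitives carry no explicit attribute.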
std::shared_ptr<ops::DecoderLayer> DecoderLayerFusion::CreatePrim(const FuncGraphPtr &func_graph,
                                                                  const EquivPtr &equiv, bool post_layernorm,
                                                                  int64_t ffn_hidden_size) const {
  auto decoder_layer_prim = std::make_shared<ops::DecoderLayer>();
  if (decoder_layer_prim == nullptr) {
    MS_LOG(ERROR) << "Build decoder layer primitive failed.";
    return nullptr;
  }
  int head_num = 0;
  int head_size = 0;
  float eps1 = 1e-6;
  float eps2 = 1e-6;
  float eps3 = 1e-6;
  float eps4 = 1e-6;
  bool is_position_bias1 = false;
  bool is_position_bias2 = false;
  float scale1 = 1.0f;
  float scale2 = 1.0f;
  if (CheckPattern(func_graph, equiv, &head_num, &head_size, &eps1, &eps2, &eps3, &eps4, &is_position_bias1,
                   &is_position_bias2, &scale1, &scale2) != RET_OK) {
    return nullptr;
  }
  decoder_layer_prim->Init(head_num, head_size, eps1, eps2, eps3, eps4, ffn_hidden_size, is_position_bias1,
                           is_position_bias2, post_layernorm, scale1, scale2, act_type_, is_layernorm_);
  return decoder_layer_prim;
}
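// Assembles the replacement CNode. Inputs are gathered from the match in the
// order the fused DecoderLayer implementation consumes them; the
// position-bias (T5) variant omits the bias and beta tensors.
// ffn_hidden_size is taken from dimension 1 of the first FFN weight.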
CNodePtr DecoderLayerFusion::CreateMaskedDecoderLayerFusionNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
                                                                const AnfNodePtr &node, bool post_layernorm = true,
                                                                bool mask = true) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_ASSERT(node != nullptr);
  auto input = utils::cast<AnfNodePtr>((*equiv)[hidden_stats_]);
  MS_ASSERT(input != nullptr);
  auto encoder_output = utils::cast<AnfNodePtr>((*equiv)[encoder_output_]);
  MS_ASSERT(encoder_output != nullptr);
  AnfNodePtr position_bias, input_mask, bias_attn_o, bias_attn_qkv, beta1, beta2, bias_m, bias_p, beta3,
    bias_attn_cross_qkv, bias_attn_cross_o, position_bias_cross, gamma4, beta4;
  auto weight_qkv = utils::cast<AnfNodePtr>((*equiv)[weight_attn_qkv_]);
  auto weight_attn_o = utils::cast<AnfNodePtr>((*equiv)[weight_attn_o_]);
  auto weight_attn_q = utils::cast<AnfNodePtr>((*equiv)[weight_attn_q_]);
  auto weight_attn_kv = utils::cast<AnfNodePtr>((*equiv)[weight_attn_kv_]);
  auto weight_attn_cross_o = utils::cast<AnfNodePtr>((*equiv)[weight_attn_cross_o_]);
  auto weight_m = utils::cast<AnfNodePtr>((*equiv)[weight_m_]);
  auto weight_p = utils::cast<AnfNodePtr>((*equiv)[weight_p_]);
  if (is_position_bias_) {
    position_bias = utils::cast<AnfNodePtr>((*equiv)[position_bias_]);
    position_bias_cross = utils::cast<AnfNodePtr>((*equiv)[position_bias_cross_]);
  } else {
    bias_attn_o = utils::cast<AnfNodePtr>((*equiv)[bias_attn_o_]);
    bias_attn_qkv = utils::cast<AnfNodePtr>((*equiv)[bias_attn_qkv_]);
    bias_attn_cross_qkv = utils::cast<AnfNodePtr>((*equiv)[bias_attn_cross_qkv_]);
    bias_attn_cross_o = utils::cast<AnfNodePtr>((*equiv)[bias_attn_cross_o_]);
    bias_m = utils::cast<AnfNodePtr>((*equiv)[bias_m_]);
    bias_p = utils::cast<AnfNodePtr>((*equiv)[bias_p_]);
    beta1 = utils::cast<AnfNodePtr>((*equiv)[beta1_]);
    beta2 = utils::cast<AnfNodePtr>((*equiv)[beta2_]);
    beta3 = utils::cast<AnfNodePtr>((*equiv)[beta3_]);
    if (is_layernorm_) beta4 = utils::cast<AnfNodePtr>((*equiv)[beta4_]);
  }
  auto gamma1 = utils::cast<AnfNodePtr>((*equiv)[gamma1_]);
  auto gamma2 = utils::cast<AnfNodePtr>((*equiv)[gamma2_]);
  auto gamma3 = utils::cast<AnfNodePtr>((*equiv)[gamma3_]);
  if (is_layernorm_) gamma4 = utils::cast<AnfNodePtr>((*equiv)[gamma4_]);
  input_mask = mask ? utils::cast<AnfNodePtr>((*equiv)[mask_]) : nullptr;
  auto cross_mask = mask ? utils::cast<AnfNodePtr>((*equiv)[cross_mask_]) : nullptr;
  auto base_shape_ptr = weight_m->Shape();
  MS_EXCEPTION_IF_NULL(base_shape_ptr);
  auto input_shape_ptr = base_shape_ptr->cast<abstract::ShapePtr>();
  MS_EXCEPTION_IF_NULL(input_shape_ptr);
  auto input_shape = input_shape_ptr->shape();
  MS_CHECK_TRUE_RET(input_shape.size() > 1, nullptr);
  auto ffn_hidden_size = static_cast<int64_t>(input_shape[1]);
  auto decoder_layer_prim = CreatePrim(func_graph, equiv, post_layernorm, ffn_hidden_size);
  MS_CHECK_TRUE_RET(decoder_layer_prim != nullptr, nullptr);
  auto decoder_layer_prim_c = decoder_layer_prim->GetPrim();
  MS_CHECK_TRUE_RET(decoder_layer_prim_c != nullptr, nullptr);
  auto value_node = NewValueNode(decoder_layer_prim_c);
  MS_CHECK_TRUE_RET(value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> new_node_inputs = {value_node, input, gamma1};
  if (is_position_bias_) {
    new_node_inputs.insert(new_node_inputs.end(), {weight_qkv});
    if (mask) new_node_inputs.push_back(input_mask);
    new_node_inputs.insert(new_node_inputs.end(),
                           {position_bias, weight_attn_o, gamma2, encoder_output, weight_attn_q, weight_attn_kv});
    if (mask) new_node_inputs.push_back(cross_mask);
    new_node_inputs.insert(new_node_inputs.end(),
                           {position_bias_cross, weight_attn_cross_o, gamma3, weight_m, weight_p});
    if (is_layernorm_) new_node_inputs.push_back(gamma4);
  } else {
    new_node_inputs.insert(new_node_inputs.end(), {beta1, weight_qkv, bias_attn_qkv});
    if (mask) new_node_inputs.push_back(input_mask);
    new_node_inputs.insert(new_node_inputs.end(), {weight_attn_o, bias_attn_o, gamma2, beta2, encoder_output,
                                                   weight_attn_q, weight_attn_kv, bias_attn_cross_qkv});
    if (mask) new_node_inputs.push_back(cross_mask);
    new_node_inputs.insert(new_node_inputs.end(),
                           {weight_attn_cross_o, bias_attn_cross_o, gamma3, beta3, weight_m, bias_m, weight_p, bias_p});
    if (is_layernorm_) new_node_inputs.insert(new_node_inputs.end(), {gamma4, beta4});
  }
  auto new_node = func_graph->NewCNode(new_node_inputs);
  MS_CHECK_TRUE_RET(new_node != nullptr, nullptr);
  auto old_node = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(old_node != nullptr, nullptr);
  MS_CHECK_TRUE_RET(old_node->abstract() != nullptr, nullptr);
  new_node->set_abstract(old_node->abstract()->Clone());
  new_node->set_fullname_with_scope(node->fullname_with_scope() + "/decoder_layer");
  return new_node;
}
} // namespace mindspore::opt