1 /**
2 * Copyright 2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "plugin/device/ascend/optimizer/ir_fusion/inference_matmul_split_fusion.h"
17 #include <vector>
18 #include <set>
19 #include "plugin/device/ascend/optimizer/common/gllo_utils.h"
20 #include "mindspore/core/ops/nn_ops.h"
21 #include "mindspore/core/ops/math_ops.h"
22 #include "include/backend/optimizer/helper.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/common/utils/utils.h"
26 #include "utils/ms_context.h"
27 #include "utils/trace_base.h"
28
29 namespace mindspore {
30 namespace opt {
31
Run(const FuncGraphPtr & graph)32 bool InferenceMatmulSplitFusion::Run(const FuncGraphPtr &graph) {
33 auto kernel_graph = graph->cast<KernelGraphPtr>();
34 MS_EXCEPTION_IF_NULL(kernel_graph);
35 bool changed = false;
36 auto ms_context = MsContext::GetInstance();
37 MS_EXCEPTION_IF_NULL(ms_context);
38 if (!ms_context->IsEnableInferBoost()) {
39 return false;
40 }
41 constexpr auto kInferenceMatmulSplitSiluName = "InferenceMatmulSplitSilu";
42 constexpr auto kInferenceMatmulSplitName = "InferenceMatmulSplit";
43 auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
44 auto enable_fusion =
45 (std::find(enable_op_list.begin(), enable_op_list.end(), kInferenceMatmulSplitName) != enable_op_list.end());
46 if (!enable_fusion) {
47 return false;
48 }
49 enable_fusion_silu =
50 (std::find(enable_op_list.begin(), enable_op_list.end(), kInferenceMatmulSplitSiluName) != enable_op_list.end());
51
52 std::string pattern_name = "";
53 auto node_list = TopoSort(graph->output());
54 std::reverse(node_list.begin(), node_list.end());
55 for (const auto &node : node_list) {
56 if (node == nullptr || !node->isa<CNode>()) {
57 continue;
58 }
59 auto cnode = node->cast<CNodePtr>();
60 auto node_name = common::AnfAlgo::GetCNodeName(cnode);
61 if (node_name != prim::kPrimSplitWithSize->name() && node_name != prim::kPrimSiLU->name()) {
62 continue;
63 }
64 if (visited_cnodes.find(cnode) != visited_cnodes.end()) {
65 continue;
66 }
67 pattern_name = GetFusionPatternName(cnode);
68 MS_LOG(DEBUG) << "fusion pattern is : " << pattern_name;
69 if (!pattern_name.empty()) {
70 auto new_node = Process(pattern_name, graph, node);
71 changed |= new_node != nullptr;
72 }
73 }
74 return changed;
75 }
76
GetSplitFusionPatternName(const CNodePtr & cnode) const77 std::string InferenceMatmulSplitFusion::GetSplitFusionPatternName(const CNodePtr &cnode) const {
78 std::string pattern_name = "";
79 auto reshape_node = common::AnfAlgo::GetInputNode(cnode, kIndex0);
80 if (reshape_node == nullptr || !reshape_node->isa<CNode>()) {
81 return "";
82 }
83 auto reshape_node_name = common::AnfAlgo::GetCNodeName(reshape_node);
84 if (reshape_node_name != prim::kPrimReshape->name()) {
85 MS_LOG(DEBUG) << "reshape node name is: " << reshape_node_name;
86 return "";
87 }
88 auto reshape_cnode = reshape_node->cast<CNodePtr>();
89 auto reshape_input_node = common::AnfAlgo::GetInputNode(reshape_cnode, kIndex0);
90 if (reshape_input_node != nullptr && reshape_input_node->isa<CNode>()) {
91 auto reshape_input_name = common::AnfAlgo::GetCNodeName(reshape_input_node);
92 if (reshape_input_name == prim::kPrimMatMul->name()) {
93 MS_LOG(DEBUG) << "process matmul reshape split fusion";
94 pattern_name = kPatternNameMatMulSplit;
95 } else if (reshape_input_name == prim::kPrimQuantBatchMatmul->name()) {
96 MS_LOG(DEBUG) << "process quant_batch_matmul reshape split fusion";
97 pattern_name = kPatternNameQuantbatchmatmulSplit;
98 } else if (reshape_input_name == prim::kPrimAdd->name()) {
99 auto bias_add_cnode = reshape_input_node->cast<CNodePtr>();
100 auto bias_input_node = common::AnfAlgo::GetInputNode(bias_add_cnode, kIndex0);
101 if (bias_input_node->isa<CNode>() &&
102 common::AnfAlgo::GetCNodeName(bias_input_node) == prim::kPrimMatMul->name()) {
103 MS_LOG(DEBUG) << "process matmul biasadd reshape split fusion";
104 pattern_name = kPatternNameMatMulBiasAddSplit;
105 }
106 }
107 }
108 return pattern_name;
109 }
110
GetFusionPatternName(const CNodePtr & cnode) const111 std::string InferenceMatmulSplitFusion::GetFusionPatternName(const CNodePtr &cnode) const {
112 std::string pattern_name = "";
113 auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
114 if (cnode_name == prim::kPrimSiLU->name()) {
115 if (!enable_fusion_silu) {
116 MS_LOG(DEBUG) << "disable matmul split silu fusion";
117 return "";
118 }
119 auto silu_input_node = common::AnfAlgo::GetInputNode(cnode, kIndex0);
120 auto silu_input_name = common::AnfAlgo::GetCNodeName(silu_input_node);
121 if (silu_input_name == prim::kPrimTupleGetItem->name()) {
122 auto silu_input_cnode = silu_input_node->cast<CNodePtr>();
123 auto item_input_node = common::AnfAlgo::GetInputNode(silu_input_cnode, kIndex0);
124 auto item_input_name = common::AnfAlgo::GetCNodeName(item_input_node);
125 if (item_input_name == prim::kPrimSplitWithSize->name()) {
126 auto item_input_cnode = item_input_node->cast<CNodePtr>();
127 auto split_pattern_name = GetSplitFusionPatternName(item_input_cnode);
128 if (!split_pattern_name.empty()) {
129 pattern_name = split_pattern_name + "Silu";
130 }
131 }
132 }
133 } else if (cnode_name == prim::kPrimSplitWithSize->name()) {
134 pattern_name = GetSplitFusionPatternName(cnode);
135 }
136 return pattern_name;
137 }
138
CheckMatMulDataFormat(const CNodePtr & matmul_cnode) const139 bool InferenceMatmulSplitFusion::CheckMatMulDataFormat(const CNodePtr &matmul_cnode) const {
140 MS_CHECK_TRUE_RET(matmul_cnode != nullptr, false);
141 size_t trans_a_index = 0;
142 size_t trans_b_index = 0;
143 auto cnode_name = common::AnfAlgo::GetCNodeName(matmul_cnode);
144 if (cnode_name == prim::kPrimQuantBatchMatmul->name()) {
145 trans_a_index = kIndex6;
146 trans_b_index = kIndex7;
147 } else if (cnode_name == prim::kPrimMatMul->name()) {
148 trans_a_index = kIndex3;
149 trans_b_index = kIndex4;
150 }
151 auto trans_a = matmul_cnode->input(trans_a_index)->cast<ValueNodePtr>();
152 MS_CHECK_TRUE_RET(trans_a != nullptr, false);
153 auto trans_b = matmul_cnode->input(trans_b_index)->cast<ValueNodePtr>();
154 MS_CHECK_TRUE_RET(trans_b != nullptr, false);
155 bool is_trans_a = GetValue<bool>(trans_a->value());
156 bool is_trans_b = GetValue<bool>(trans_b->value());
157 if (!is_trans_a && is_trans_b) {
158 return true;
159 }
160 return false;
161 }
162
GetSplitSizeLen(const CNodePtr & split_cnode) const163 size_t InferenceMatmulSplitFusion::GetSplitSizeLen(const CNodePtr &split_cnode) const {
164 auto split_size = split_cnode->input(kIndex2)->cast<ValueNodePtr>();
165 if (split_size == nullptr || !split_size->isa<ValueNode>()) {
166 MS_LOG(DEBUG) << "split size node is nullptr";
167 return 0;
168 }
169 auto split_size_shape = GetValue<std::vector<int64_t>>(split_size->value());
170 size_t split_size_len = split_size_shape.size();
171 return split_size_len;
172 }
173
CreateMatmulSplitPrim(const CNodePtr & split_cnode,size_t split_size_len,const std::string & pattern_name) const174 PrimitivePtr InferenceMatmulSplitFusion::CreateMatmulSplitPrim(const CNodePtr &split_cnode, size_t split_size_len,
175 const std::string &pattern_name) const {
176 PrimitivePtr matmul_split_prim = nullptr;
177 std::string prim_name = "";
178 auto iter = PatternPrimMap.find(split_size_len);
179 if (iter != PatternPrimMap.end()) {
180 auto iter_n = iter->second.find(pattern_name);
181 if (iter_n != iter->second.end()) {
182 prim_name = iter_n->second;
183 }
184 }
185 MS_CHECK_TRUE_RET(!prim_name.empty(), nullptr);
186 matmul_split_prim = std::make_shared<Primitive>(prim_name);
187 MS_CHECK_TRUE_RET(matmul_split_prim != nullptr, nullptr);
188 auto split_size = split_cnode->input(kIndex2)->cast<ValueNodePtr>();
189 matmul_split_prim->AddAttr("n_lens", split_size->value());
190 return matmul_split_prim;
191 }
192
CreateMatmulSplitNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const193 CNodePtr InferenceMatmulSplitFusion::CreateMatmulSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
194 const std::string &pattern_name) const {
195 MS_LOG(DEBUG) << "start create MatmulSplit node";
196 MS_ASSERT(func_graph != nullptr && node != nullptr);
197 auto split_cnode = node->cast<CNodePtr>();
198 MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
199
200 auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
201 MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
202 auto tuple_node = reshape_cnode->input(kIndex2);
203 MS_CHECK_TRUE_RET(tuple_node != nullptr, nullptr);
204
205 auto matmul_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
206 MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
207 MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), nullptr);
208
209 auto input_x = matmul_cnode->input(kIndex1);
210 MS_CHECK_TRUE_RET(input_x != nullptr, nullptr);
211 auto input_w = matmul_cnode->input(kIndex2);
212 MS_CHECK_TRUE_RET(input_w != nullptr, nullptr);
213 const std::set<TypeId> support_dtype = {kNumberTypeFloat16, kNumberTypeBFloat16};
214 if (!CheckSupportDataType(input_x, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) {
215 return nullptr;
216 }
217
218 size_t split_size_len = GetSplitSizeLen(split_cnode);
219 auto matmul_split_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
220 std::vector<AnfNodePtr> matmul_split_inputs = {input_x, input_w, tuple_node};
221 auto matmul_split_cnode = func_graph->NewCNode(matmul_split_prim, matmul_split_inputs);
222 MS_EXCEPTION_IF_NULL(matmul_split_cnode);
223
224 matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-SplitWithSize");
225 if (node->abstract() != nullptr) {
226 matmul_split_cnode->set_abstract(split_cnode->abstract()->Clone());
227 }
228 visited_cnodes.insert(split_cnode);
229 MS_LOG(DEBUG) << "create MatmulSplit node success.";
230 return matmul_split_cnode;
231 }
232
CreateMatmulBiasAddSplitNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const233 CNodePtr InferenceMatmulSplitFusion::CreateMatmulBiasAddSplitNode(const FuncGraphPtr &func_graph,
234 const AnfNodePtr &node,
235 const std::string &pattern_name) const {
236 MS_LOG(DEBUG) << "start create MatmulBiasAddSplit node";
237 MS_ASSERT(func_graph != nullptr && node != nullptr);
238 auto split_cnode = node->cast<CNodePtr>();
239 MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
240
241 auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
242 MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
243 auto reshape_tuple = reshape_cnode->input(kIndex2);
244 MS_CHECK_TRUE_RET(reshape_tuple != nullptr, nullptr);
245
246 auto biasAdd_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
247 MS_CHECK_TRUE_RET(biasAdd_cnode != nullptr, nullptr);
248 auto matmul_cnode = biasAdd_cnode->input(kIndex1)->cast<CNodePtr>();
249 MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
250 MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), {});
251
252 auto matmul_x = matmul_cnode->input(kIndex1);
253 MS_EXCEPTION_IF_NULL(matmul_x);
254 auto matmul_w = matmul_cnode->input(kIndex2);
255 MS_EXCEPTION_IF_NULL(matmul_w);
256 auto input_bias = biasAdd_cnode->input(kIndex2);
257 MS_EXCEPTION_IF_NULL(input_bias);
258 const std::set<TypeId> support_dtype = {kNumberTypeFloat16};
259 if (!CheckSupportDataType(matmul_x, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) {
260 return nullptr;
261 }
262 size_t split_size_len = GetSplitSizeLen(split_cnode);
263 auto matmul_split_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
264 matmul_split_prim->AddAttr("with_bias", MakeValue<bool>(true));
265 std::vector<AnfNodePtr> matmul_split_inputs = {matmul_x, matmul_w, reshape_tuple, input_bias};
266 auto matmul_split_cnode = func_graph->NewCNode(matmul_split_prim, matmul_split_inputs);
267 MS_EXCEPTION_IF_NULL(matmul_split_cnode);
268
269 matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-BiasAddSplitWithSize");
270 if (node->abstract() != nullptr) {
271 matmul_split_cnode->set_abstract(split_cnode->abstract()->Clone());
272 }
273 visited_cnodes.insert(split_cnode);
274 MS_LOG(DEBUG) << "create MatmulBiasAddSplit node success.";
275 return matmul_split_cnode;
276 }
277
CreateQuantbatchmatmulSplitNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const278 CNodePtr InferenceMatmulSplitFusion::CreateQuantbatchmatmulSplitNode(const FuncGraphPtr &func_graph,
279 const AnfNodePtr &node,
280 const std::string &pattern_name) const {
281 MS_LOG(DEBUG) << "start create QuantbatchmatmulSplit node";
282 MS_ASSERT(func_graph != nullptr && node != nullptr);
283 auto split_cnode = node->cast<CNodePtr>();
284 MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
285
286 auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
287 MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
288 auto qbmm_tuple = reshape_cnode->input(kIndex2);
289 MS_CHECK_TRUE_RET(qbmm_tuple != nullptr, nullptr);
290 auto qbmm_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
291 MS_CHECK_TRUE_RET(qbmm_cnode != nullptr, nullptr);
292 MS_CHECK_TRUE_RET(qbmm_cnode->func_graph() == split_cnode->func_graph(), nullptr);
293
294 auto input_x = qbmm_cnode->input(kIndex1);
295 MS_EXCEPTION_IF_NULL(input_x);
296 auto input_w = qbmm_cnode->input(kIndex2);
297 MS_EXCEPTION_IF_NULL(input_w);
298 auto input_bias = qbmm_cnode->input(kIndex5);
299 MS_EXCEPTION_IF_NULL(input_bias);
300 auto input_scale = qbmm_cnode->input(kIndex3);
301 MS_EXCEPTION_IF_NULL(input_scale);
302 const std::set<TypeId> support_dtype = {kNumberTypeInt8};
303 if (!CheckSupportDataType(input_x, support_dtype) || !CheckMatMulDataFormat(qbmm_cnode)) {
304 return nullptr;
305 }
306
307 size_t split_size_len = GetSplitSizeLen(split_cnode);
308 auto qbmm_split_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
309 std::vector<AnfNodePtr> qbmm_split_inputs = {input_x, input_w, qbmm_tuple, input_bias, input_scale};
310 auto qbmm_split_cnode = func_graph->NewCNode(qbmm_split_prim, qbmm_split_inputs);
311 MS_EXCEPTION_IF_NULL(qbmm_split_cnode);
312
313 qbmm_split_cnode->set_fullname_with_scope(qbmm_cnode->fullname_with_scope() + "-SplitWithSize");
314 if (node->abstract() != nullptr) {
315 qbmm_split_cnode->set_abstract(split_cnode->abstract()->Clone());
316 }
317 visited_cnodes.insert(split_cnode);
318 MS_LOG(DEBUG) << "create QuantbatchmatmulSplit node success.";
319 return qbmm_split_cnode;
320 }
321
// Wires the users of the old split node to the fused node for the SiLU variants.
//
// Among the users of `split_cnode` known to the graph manager, the first TupleGetItem whose index
// differs from `output_index` is redirected to read from `matmul_split_cnode` (only that one user
// is rewired). A new TupleGetItem(matmul_split_cnode, output_index) is created and given a clone
// of the SiLU node's abstract, if it has one; that new node is returned.
// Returns nullptr (leaving the graph untouched) when any node argument is null, when
// `split_cnode` has no users recorded, or when no TupleGetItem user with a different index exists.
// NOTE(review): the existing user is rewired with set_input rather than through the manager -
// confirm the manager's user bookkeeping does not need to see this edge change.
CNodePtr InferenceMatmulSplitFusion::CreateGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &split_cnode,
                                                       const CNodePtr &matmul_split_cnode, const CNodePtr &silu_cnode,
                                                       const size_t output_index) const {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_CHECK_TRUE_RET(split_cnode != nullptr && matmul_split_cnode != nullptr && silu_cnode != nullptr, nullptr);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto iter = manager->node_users().find(split_cnode);
  if (iter == manager->node_users().end()) {
    MS_LOG(DEBUG) << "node has no output in manager";
    return nullptr;
  }

  // Nothing is modified while scanning, so the user set is read in place instead of being copied.
  const auto &output_info_list = iter->second;
  CNodePtr item_other_node = nullptr;
  for (const auto &output_info : output_info_list) {
    if (common::AnfAlgo::GetCNodeName(output_info.first) != prim::kPrimTupleGetItem->name()) {
      continue;
    }
    auto item_cnode = utils::cast<CNodePtr>(output_info.first);
    const size_t used_output_index = common::AnfAlgo::GetTupleGetItemOutIndex(item_cnode);
    if (used_output_index != output_index) {
      item_other_node = item_cnode;
      break;
    }
  }
  MS_CHECK_TRUE_RET(item_other_node != nullptr, nullptr);

  // Create the replacement TupleGetItem first so that a failure leaves the graph unchanged.
  auto value0 = NewValueNode(MakeValue(static_cast<int64_t>(output_index)));
  value0->set_abstract(value0->value()->ToAbstract());
  auto new_item_cnode =
    func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem->Clone()), matmul_split_cnode, value0});
  MS_CHECK_TRUE_RET(new_item_cnode != nullptr, nullptr);
  if (silu_cnode->abstract() != nullptr) {
    new_item_cnode->set_abstract(silu_cnode->abstract()->Clone());
  }

  item_other_node->set_input(kRealInputNodeIndexInTupleGetItem, matmul_split_cnode);
  MS_LOG(DEBUG) << "create new get_item_node success.";
  return new_item_cnode;
}
360
CreateMatmulSplitSiluNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const361 CNodePtr InferenceMatmulSplitFusion::CreateMatmulSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
362 const std::string &pattern_name) const {
363 MS_LOG(DEBUG) << "start create MatmulSplitSilu node";
364 MS_ASSERT(func_graph != nullptr && node != nullptr);
365 auto silu_cnode = node->cast<CNodePtr>();
366 MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr);
367 auto item_cnode = silu_cnode->input(kIndex1)->cast<CNodePtr>();
368 MS_CHECK_TRUE_RET(item_cnode != nullptr, nullptr);
369 auto split_cnode = item_cnode->input(kIndex1)->cast<CNodePtr>();
370 MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
371
372 auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
373 MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
374 auto tuple = reshape_cnode->input(kIndex2);
375 MS_CHECK_TRUE_RET(tuple != nullptr, nullptr);
376 auto matmul_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
377 MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
378 MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), nullptr);
379
380 auto x_node = matmul_cnode->input(kIndex1);
381 MS_EXCEPTION_IF_NULL(x_node);
382 auto weight_node = matmul_cnode->input(kIndex2);
383 MS_EXCEPTION_IF_NULL(weight_node);
384 const std::set<TypeId> support_dtype = {kNumberTypeFloat16, kNumberTypeBFloat16};
385 if (!CheckSupportDataType(x_node, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) {
386 return nullptr;
387 }
388 size_t split_size_len = GetSplitSizeLen(split_cnode);
389 if (split_size_len != kMatmulFfnSplitSizeLen) {
390 MS_LOG(DEBUG) << "MatmulSplitSilu only support ffn output";
391 return nullptr;
392 }
393 auto fusion_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
394 size_t output_index = common::AnfAlgo::GetTupleGetItemOutIndex(item_cnode);
395 fusion_prim->AddAttr("silu_position", MakeValue<int32_t>(output_index));
396 std::vector<AnfNodePtr> matmul_split_inputs = {x_node, weight_node, tuple};
397 auto matmul_split_cnode = func_graph->NewCNode(fusion_prim, matmul_split_inputs);
398 MS_EXCEPTION_IF_NULL(matmul_split_cnode);
399
400 auto new_item_cnode = CreateGetItemNode(func_graph, split_cnode, matmul_split_cnode, silu_cnode, output_index);
401 MS_CHECK_TRUE_RET(new_item_cnode != nullptr, nullptr);
402 matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-SplitWithSizeSilu");
403 if (node->abstract() != nullptr) {
404 matmul_split_cnode->set_abstract(split_cnode->abstract()->Clone());
405 }
406 visited_cnodes.insert({silu_cnode, split_cnode});
407 MS_LOG(DEBUG) << "create MatmulSplitSilu node success.";
408 return new_item_cnode;
409 }
410
CreateMatmulBiasAddSplitSiluNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const411 CNodePtr InferenceMatmulSplitFusion::CreateMatmulBiasAddSplitSiluNode(const FuncGraphPtr &func_graph,
412 const AnfNodePtr &node,
413 const std::string &pattern_name) const {
414 MS_LOG(DEBUG) << "start create MatmulBiasAddSplitSilu node";
415 MS_ASSERT(func_graph != nullptr && node != nullptr);
416 auto silu_cnode = node->cast<CNodePtr>();
417 MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr);
418 auto get_item_cnode = silu_cnode->input(kIndex1)->cast<CNodePtr>();
419 MS_CHECK_TRUE_RET(get_item_cnode != nullptr, nullptr);
420 auto split_cnode = get_item_cnode->input(kIndex1)->cast<CNodePtr>();
421 MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
422
423 auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
424 MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
425 auto tuple_node = reshape_cnode->input(kIndex2);
426 MS_CHECK_TRUE_RET(tuple_node != nullptr, nullptr);
427 auto biasAdd_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
428 MS_CHECK_TRUE_RET(biasAdd_cnode != nullptr, nullptr);
429
430 auto matmul_cnode = biasAdd_cnode->input(kIndex1)->cast<CNodePtr>();
431 MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
432 MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), {});
433
434 auto matmul_input = matmul_cnode->input(kIndex1);
435 MS_EXCEPTION_IF_NULL(matmul_input);
436 auto input_w = matmul_cnode->input(kIndex2);
437 MS_EXCEPTION_IF_NULL(input_w);
438 auto input_bias = biasAdd_cnode->input(kIndex2);
439 MS_EXCEPTION_IF_NULL(input_bias);
440 const std::set<TypeId> support_dtype = {kNumberTypeFloat16};
441 if (!CheckSupportDataType(matmul_input, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) {
442 return nullptr;
443 }
444 size_t split_len = GetSplitSizeLen(split_cnode);
445 if (split_len != kMatmulFfnSplitSizeLen) {
446 MS_LOG(DEBUG) << "MatmulBiasAddSplitSilu only support ffn output";
447 return nullptr;
448 }
449 auto matmul_split_prim = CreateMatmulSplitPrim(split_cnode, split_len, pattern_name);
450 size_t output_index = common::AnfAlgo::GetTupleGetItemOutIndex(get_item_cnode);
451 matmul_split_prim->AddAttr("silu_position", MakeValue<int32_t>(output_index));
452 matmul_split_prim->AddAttr("with_bias", MakeValue<bool>(true));
453 std::vector<AnfNodePtr> matmul_split_inputs = {matmul_input, input_w, tuple_node, input_bias};
454 auto matmul_split_cnode = func_graph->NewCNode(matmul_split_prim, matmul_split_inputs);
455 MS_EXCEPTION_IF_NULL(matmul_split_cnode);
456
457 auto new_item_cnode = CreateGetItemNode(func_graph, split_cnode, matmul_split_cnode, silu_cnode, output_index);
458 MS_CHECK_TRUE_RET(new_item_cnode != nullptr, nullptr);
459 matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-BiasAddSplitWithSizeSilu");
460 if (node->abstract() != nullptr) {
461 matmul_split_cnode->set_abstract(split_cnode->abstract()->Clone());
462 }
463 visited_cnodes.insert({silu_cnode, split_cnode});
464 MS_LOG(DEBUG) << "create MatmulBiasAddSplitSilu node success.";
465 return new_item_cnode;
466 }
467
CreateQuantbatchmatmulSplitSiluNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const468 CNodePtr InferenceMatmulSplitFusion::CreateQuantbatchmatmulSplitSiluNode(const FuncGraphPtr &func_graph,
469 const AnfNodePtr &node,
470 const std::string &pattern_name) const {
471 MS_LOG(DEBUG) << "start create QuantbatchmatmulSplitSilu node";
472 MS_ASSERT(func_graph != nullptr && node != nullptr);
473 auto silu_cnode = node->cast<CNodePtr>();
474 MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr);
475 auto item_cnode = silu_cnode->input(kIndex1)->cast<CNodePtr>();
476 MS_CHECK_TRUE_RET(item_cnode != nullptr, nullptr);
477 auto split_cnode = item_cnode->input(kIndex1)->cast<CNodePtr>();
478 MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
479
480 auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
481 MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
482 auto reshape_tuple = reshape_cnode->input(kIndex2);
483 MS_CHECK_TRUE_RET(reshape_tuple != nullptr, nullptr);
484 auto qbmm_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
485 MS_CHECK_TRUE_RET(qbmm_cnode != nullptr, nullptr);
486 MS_CHECK_TRUE_RET(qbmm_cnode->func_graph() == split_cnode->func_graph(), nullptr);
487
488 auto qbmm_x = qbmm_cnode->input(kIndex1);
489 MS_EXCEPTION_IF_NULL(qbmm_x);
490 auto qbmm_w = qbmm_cnode->input(kIndex2);
491 MS_EXCEPTION_IF_NULL(qbmm_w);
492 auto input_bias = qbmm_cnode->input(kIndex5);
493 MS_EXCEPTION_IF_NULL(input_bias);
494 auto input_scale = qbmm_cnode->input(kIndex3);
495 MS_EXCEPTION_IF_NULL(input_scale);
496 const std::set<TypeId> support_dtype = {kNumberTypeInt8};
497 if (!CheckSupportDataType(qbmm_x, support_dtype) || !CheckMatMulDataFormat(qbmm_cnode)) {
498 return nullptr;
499 }
500 size_t split_size_len = GetSplitSizeLen(split_cnode);
501 if (split_size_len != kMatmulFfnSplitSizeLen) {
502 MS_LOG(DEBUG) << "QuantbatchmatmulSplitSilu only support ffn output";
503 return nullptr;
504 }
505 auto qbmm_split_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
506 size_t output_index = common::AnfAlgo::GetTupleGetItemOutIndex(item_cnode);
507 qbmm_split_prim->AddAttr("silu_position", MakeValue<int32_t>(output_index));
508 std::vector<AnfNodePtr> qbmm_split_inputs = {qbmm_x, qbmm_w, reshape_tuple, input_bias, input_scale};
509 auto qbmm_split_cnode = func_graph->NewCNode(qbmm_split_prim, qbmm_split_inputs);
510 MS_EXCEPTION_IF_NULL(qbmm_split_cnode);
511
512 auto new_item_cnode = CreateGetItemNode(func_graph, split_cnode, qbmm_split_cnode, silu_cnode, output_index);
513 MS_CHECK_TRUE_RET(new_item_cnode != nullptr, nullptr);
514 qbmm_split_cnode->set_fullname_with_scope(qbmm_cnode->fullname_with_scope() + "-SplitWithSizeSilu");
515 if (node->abstract() != nullptr) {
516 qbmm_split_cnode->set_abstract(split_cnode->abstract()->Clone());
517 }
518 visited_cnodes.insert({silu_cnode, split_cnode});
519 MS_LOG(DEBUG) << "create QuantbatchmatmulSplitSilu node success.";
520 return new_item_cnode;
521 }
522
// Performs one fusion: builds the replacement node for `pattern_name` and substitutes it for
// `node` through the graph manager.
//
// `node` is the node the pattern was matched on (a SplitWithSize node for the plain patterns, a
// SiLU node for the *Silu patterns). Returns the replacement node, or nullptr when `node` is not a
// CNode, the pattern name is unknown, or the matching Create* function declined to fuse.
// Throws when `node`, `func_graph` or its manager is null.
AnfNodePtr InferenceMatmulSplitFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                               const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);

  auto origin_cnode = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(origin_cnode != nullptr, nullptr);

  CNodePtr replacement = nullptr;
  // Every pattern is compared in turn, in the same order as a plain sequence of `if` statements.
  const auto build_if_matches = [&](const auto &candidate_pattern, auto creator) {
    if (pattern_name == candidate_pattern) {
      replacement = (this->*creator)(func_graph, node, pattern_name);
    }
  };
  build_if_matches(kPatternNameMatMulSplit, &InferenceMatmulSplitFusion::CreateMatmulSplitNode);
  build_if_matches(kPatternNameMatMulBiasAddSplit, &InferenceMatmulSplitFusion::CreateMatmulBiasAddSplitNode);
  build_if_matches(kPatternNameQuantbatchmatmulSplit, &InferenceMatmulSplitFusion::CreateQuantbatchmatmulSplitNode);
  build_if_matches(kPatternNameMatMulSplitSilu, &InferenceMatmulSplitFusion::CreateMatmulSplitSiluNode);
  build_if_matches(kPatternNameMatMulBiasAddSplitSilu, &InferenceMatmulSplitFusion::CreateMatmulBiasAddSplitSiluNode);
  build_if_matches(kPatternNameQuantbatchmatmulSplitSilu,
                   &InferenceMatmulSplitFusion::CreateQuantbatchmatmulSplitSiluNode);
  MS_CHECK_TRUE_RET(replacement != nullptr, nullptr);

  (void)manager->Replace(origin_cnode, replacement);
  MS_LOG(DEBUG) << "MatmulSplit replace success";
  return replacement;
}
559 } // namespace opt
560 } // namespace mindspore
561