• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "plugin/device/ascend/optimizer/ir_fusion/inference_matmul_split_fusion.h"
17 #include <vector>
18 #include <set>
19 #include "plugin/device/ascend/optimizer/common/gllo_utils.h"
20 #include "mindspore/core/ops/nn_ops.h"
21 #include "mindspore/core/ops/math_ops.h"
22 #include "include/backend/optimizer/helper.h"
23 #include "include/backend/anf_runtime_algorithm.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "include/common/utils/utils.h"
26 #include "utils/ms_context.h"
27 #include "utils/trace_base.h"
28 
29 namespace mindspore {
30 namespace opt {
31 
Run(const FuncGraphPtr & graph)32 bool InferenceMatmulSplitFusion::Run(const FuncGraphPtr &graph) {
33   auto kernel_graph = graph->cast<KernelGraphPtr>();
34   MS_EXCEPTION_IF_NULL(kernel_graph);
35   bool changed = false;
36   auto ms_context = MsContext::GetInstance();
37   MS_EXCEPTION_IF_NULL(ms_context);
38   if (!ms_context->IsEnableInferBoost()) {
39     return false;
40   }
41   constexpr auto kInferenceMatmulSplitSiluName = "InferenceMatmulSplitSilu";
42   constexpr auto kInferenceMatmulSplitName = "InferenceMatmulSplit";
43   auto enable_op_list = ms_context->ms_internal_enable_custom_kernel_list();
44   auto enable_fusion =
45     (std::find(enable_op_list.begin(), enable_op_list.end(), kInferenceMatmulSplitName) != enable_op_list.end());
46   if (!enable_fusion) {
47     return false;
48   }
49   enable_fusion_silu =
50     (std::find(enable_op_list.begin(), enable_op_list.end(), kInferenceMatmulSplitSiluName) != enable_op_list.end());
51 
52   std::string pattern_name = "";
53   auto node_list = TopoSort(graph->output());
54   std::reverse(node_list.begin(), node_list.end());
55   for (const auto &node : node_list) {
56     if (node == nullptr || !node->isa<CNode>()) {
57       continue;
58     }
59     auto cnode = node->cast<CNodePtr>();
60     auto node_name = common::AnfAlgo::GetCNodeName(cnode);
61     if (node_name != prim::kPrimSplitWithSize->name() && node_name != prim::kPrimSiLU->name()) {
62       continue;
63     }
64     if (visited_cnodes.find(cnode) != visited_cnodes.end()) {
65       continue;
66     }
67     pattern_name = GetFusionPatternName(cnode);
68     MS_LOG(DEBUG) << "fusion pattern is : " << pattern_name;
69     if (!pattern_name.empty()) {
70       auto new_node = Process(pattern_name, graph, node);
71       changed |= new_node != nullptr;
72     }
73   }
74   return changed;
75 }
76 
GetSplitFusionPatternName(const CNodePtr & cnode) const77 std::string InferenceMatmulSplitFusion::GetSplitFusionPatternName(const CNodePtr &cnode) const {
78   std::string pattern_name = "";
79   auto reshape_node = common::AnfAlgo::GetInputNode(cnode, kIndex0);
80   if (reshape_node == nullptr || !reshape_node->isa<CNode>()) {
81     return "";
82   }
83   auto reshape_node_name = common::AnfAlgo::GetCNodeName(reshape_node);
84   if (reshape_node_name != prim::kPrimReshape->name()) {
85     MS_LOG(DEBUG) << "reshape node name is: " << reshape_node_name;
86     return "";
87   }
88   auto reshape_cnode = reshape_node->cast<CNodePtr>();
89   auto reshape_input_node = common::AnfAlgo::GetInputNode(reshape_cnode, kIndex0);
90   if (reshape_input_node != nullptr && reshape_input_node->isa<CNode>()) {
91     auto reshape_input_name = common::AnfAlgo::GetCNodeName(reshape_input_node);
92     if (reshape_input_name == prim::kPrimMatMul->name()) {
93       MS_LOG(DEBUG) << "process matmul reshape split fusion";
94       pattern_name = kPatternNameMatMulSplit;
95     } else if (reshape_input_name == prim::kPrimQuantBatchMatmul->name()) {
96       MS_LOG(DEBUG) << "process quant_batch_matmul reshape split fusion";
97       pattern_name = kPatternNameQuantbatchmatmulSplit;
98     } else if (reshape_input_name == prim::kPrimAdd->name()) {
99       auto bias_add_cnode = reshape_input_node->cast<CNodePtr>();
100       auto bias_input_node = common::AnfAlgo::GetInputNode(bias_add_cnode, kIndex0);
101       if (bias_input_node->isa<CNode>() &&
102           common::AnfAlgo::GetCNodeName(bias_input_node) == prim::kPrimMatMul->name()) {
103         MS_LOG(DEBUG) << "process matmul biasadd reshape split fusion";
104         pattern_name = kPatternNameMatMulBiasAddSplit;
105       }
106     }
107   }
108   return pattern_name;
109 }
110 
GetFusionPatternName(const CNodePtr & cnode) const111 std::string InferenceMatmulSplitFusion::GetFusionPatternName(const CNodePtr &cnode) const {
112   std::string pattern_name = "";
113   auto cnode_name = common::AnfAlgo::GetCNodeName(cnode);
114   if (cnode_name == prim::kPrimSiLU->name()) {
115     if (!enable_fusion_silu) {
116       MS_LOG(DEBUG) << "disable matmul split silu fusion";
117       return "";
118     }
119     auto silu_input_node = common::AnfAlgo::GetInputNode(cnode, kIndex0);
120     auto silu_input_name = common::AnfAlgo::GetCNodeName(silu_input_node);
121     if (silu_input_name == prim::kPrimTupleGetItem->name()) {
122       auto silu_input_cnode = silu_input_node->cast<CNodePtr>();
123       auto item_input_node = common::AnfAlgo::GetInputNode(silu_input_cnode, kIndex0);
124       auto item_input_name = common::AnfAlgo::GetCNodeName(item_input_node);
125       if (item_input_name == prim::kPrimSplitWithSize->name()) {
126         auto item_input_cnode = item_input_node->cast<CNodePtr>();
127         auto split_pattern_name = GetSplitFusionPatternName(item_input_cnode);
128         if (!split_pattern_name.empty()) {
129           pattern_name = split_pattern_name + "Silu";
130         }
131       }
132     }
133   } else if (cnode_name == prim::kPrimSplitWithSize->name()) {
134     pattern_name = GetSplitFusionPatternName(cnode);
135   }
136   return pattern_name;
137 }
138 
CheckMatMulDataFormat(const CNodePtr & matmul_cnode) const139 bool InferenceMatmulSplitFusion::CheckMatMulDataFormat(const CNodePtr &matmul_cnode) const {
140   MS_CHECK_TRUE_RET(matmul_cnode != nullptr, false);
141   size_t trans_a_index = 0;
142   size_t trans_b_index = 0;
143   auto cnode_name = common::AnfAlgo::GetCNodeName(matmul_cnode);
144   if (cnode_name == prim::kPrimQuantBatchMatmul->name()) {
145     trans_a_index = kIndex6;
146     trans_b_index = kIndex7;
147   } else if (cnode_name == prim::kPrimMatMul->name()) {
148     trans_a_index = kIndex3;
149     trans_b_index = kIndex4;
150   }
151   auto trans_a = matmul_cnode->input(trans_a_index)->cast<ValueNodePtr>();
152   MS_CHECK_TRUE_RET(trans_a != nullptr, false);
153   auto trans_b = matmul_cnode->input(trans_b_index)->cast<ValueNodePtr>();
154   MS_CHECK_TRUE_RET(trans_b != nullptr, false);
155   bool is_trans_a = GetValue<bool>(trans_a->value());
156   bool is_trans_b = GetValue<bool>(trans_b->value());
157   if (!is_trans_a && is_trans_b) {
158     return true;
159   }
160   return false;
161 }
162 
GetSplitSizeLen(const CNodePtr & split_cnode) const163 size_t InferenceMatmulSplitFusion::GetSplitSizeLen(const CNodePtr &split_cnode) const {
164   auto split_size = split_cnode->input(kIndex2)->cast<ValueNodePtr>();
165   if (split_size == nullptr || !split_size->isa<ValueNode>()) {
166     MS_LOG(DEBUG) << "split size node is nullptr";
167     return 0;
168   }
169   auto split_size_shape = GetValue<std::vector<int64_t>>(split_size->value());
170   size_t split_size_len = split_size_shape.size();
171   return split_size_len;
172 }
173 
CreateMatmulSplitPrim(const CNodePtr & split_cnode,size_t split_size_len,const std::string & pattern_name) const174 PrimitivePtr InferenceMatmulSplitFusion::CreateMatmulSplitPrim(const CNodePtr &split_cnode, size_t split_size_len,
175                                                                const std::string &pattern_name) const {
176   PrimitivePtr matmul_split_prim = nullptr;
177   std::string prim_name = "";
178   auto iter = PatternPrimMap.find(split_size_len);
179   if (iter != PatternPrimMap.end()) {
180     auto iter_n = iter->second.find(pattern_name);
181     if (iter_n != iter->second.end()) {
182       prim_name = iter_n->second;
183     }
184   }
185   MS_CHECK_TRUE_RET(!prim_name.empty(), nullptr);
186   matmul_split_prim = std::make_shared<Primitive>(prim_name);
187   MS_CHECK_TRUE_RET(matmul_split_prim != nullptr, nullptr);
188   auto split_size = split_cnode->input(kIndex2)->cast<ValueNodePtr>();
189   matmul_split_prim->AddAttr("n_lens", split_size->value());
190   return matmul_split_prim;
191 }
192 
CreateMatmulSplitNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const193 CNodePtr InferenceMatmulSplitFusion::CreateMatmulSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
194                                                            const std::string &pattern_name) const {
195   MS_LOG(DEBUG) << "start create MatmulSplit node";
196   MS_ASSERT(func_graph != nullptr && node != nullptr);
197   auto split_cnode = node->cast<CNodePtr>();
198   MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
199 
200   auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
201   MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
202   auto tuple_node = reshape_cnode->input(kIndex2);
203   MS_CHECK_TRUE_RET(tuple_node != nullptr, nullptr);
204 
205   auto matmul_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
206   MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
207   MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), nullptr);
208 
209   auto input_x = matmul_cnode->input(kIndex1);
210   MS_CHECK_TRUE_RET(input_x != nullptr, nullptr);
211   auto input_w = matmul_cnode->input(kIndex2);
212   MS_CHECK_TRUE_RET(input_w != nullptr, nullptr);
213   const std::set<TypeId> support_dtype = {kNumberTypeFloat16, kNumberTypeBFloat16};
214   if (!CheckSupportDataType(input_x, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) {
215     return nullptr;
216   }
217 
218   size_t split_size_len = GetSplitSizeLen(split_cnode);
219   auto matmul_split_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
220   std::vector<AnfNodePtr> matmul_split_inputs = {input_x, input_w, tuple_node};
221   auto matmul_split_cnode = func_graph->NewCNode(matmul_split_prim, matmul_split_inputs);
222   MS_EXCEPTION_IF_NULL(matmul_split_cnode);
223 
224   matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-SplitWithSize");
225   if (node->abstract() != nullptr) {
226     matmul_split_cnode->set_abstract(split_cnode->abstract()->Clone());
227   }
228   visited_cnodes.insert(split_cnode);
229   MS_LOG(DEBUG) << "create MatmulSplit node success.";
230   return matmul_split_cnode;
231 }
232 
CreateMatmulBiasAddSplitNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const233 CNodePtr InferenceMatmulSplitFusion::CreateMatmulBiasAddSplitNode(const FuncGraphPtr &func_graph,
234                                                                   const AnfNodePtr &node,
235                                                                   const std::string &pattern_name) const {
236   MS_LOG(DEBUG) << "start create MatmulBiasAddSplit node";
237   MS_ASSERT(func_graph != nullptr && node != nullptr);
238   auto split_cnode = node->cast<CNodePtr>();
239   MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
240 
241   auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
242   MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
243   auto reshape_tuple = reshape_cnode->input(kIndex2);
244   MS_CHECK_TRUE_RET(reshape_tuple != nullptr, nullptr);
245 
246   auto biasAdd_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
247   MS_CHECK_TRUE_RET(biasAdd_cnode != nullptr, nullptr);
248   auto matmul_cnode = biasAdd_cnode->input(kIndex1)->cast<CNodePtr>();
249   MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
250   MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), {});
251 
252   auto matmul_x = matmul_cnode->input(kIndex1);
253   MS_EXCEPTION_IF_NULL(matmul_x);
254   auto matmul_w = matmul_cnode->input(kIndex2);
255   MS_EXCEPTION_IF_NULL(matmul_w);
256   auto input_bias = biasAdd_cnode->input(kIndex2);
257   MS_EXCEPTION_IF_NULL(input_bias);
258   const std::set<TypeId> support_dtype = {kNumberTypeFloat16};
259   if (!CheckSupportDataType(matmul_x, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) {
260     return nullptr;
261   }
262   size_t split_size_len = GetSplitSizeLen(split_cnode);
263   auto matmul_split_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
264   matmul_split_prim->AddAttr("with_bias", MakeValue<bool>(true));
265   std::vector<AnfNodePtr> matmul_split_inputs = {matmul_x, matmul_w, reshape_tuple, input_bias};
266   auto matmul_split_cnode = func_graph->NewCNode(matmul_split_prim, matmul_split_inputs);
267   MS_EXCEPTION_IF_NULL(matmul_split_cnode);
268 
269   matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-BiasAddSplitWithSize");
270   if (node->abstract() != nullptr) {
271     matmul_split_cnode->set_abstract(split_cnode->abstract()->Clone());
272   }
273   visited_cnodes.insert(split_cnode);
274   MS_LOG(DEBUG) << "create MatmulBiasAddSplit node success.";
275   return matmul_split_cnode;
276 }
277 
CreateQuantbatchmatmulSplitNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const278 CNodePtr InferenceMatmulSplitFusion::CreateQuantbatchmatmulSplitNode(const FuncGraphPtr &func_graph,
279                                                                      const AnfNodePtr &node,
280                                                                      const std::string &pattern_name) const {
281   MS_LOG(DEBUG) << "start create QuantbatchmatmulSplit node";
282   MS_ASSERT(func_graph != nullptr && node != nullptr);
283   auto split_cnode = node->cast<CNodePtr>();
284   MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
285 
286   auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
287   MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
288   auto qbmm_tuple = reshape_cnode->input(kIndex2);
289   MS_CHECK_TRUE_RET(qbmm_tuple != nullptr, nullptr);
290   auto qbmm_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
291   MS_CHECK_TRUE_RET(qbmm_cnode != nullptr, nullptr);
292   MS_CHECK_TRUE_RET(qbmm_cnode->func_graph() == split_cnode->func_graph(), nullptr);
293 
294   auto input_x = qbmm_cnode->input(kIndex1);
295   MS_EXCEPTION_IF_NULL(input_x);
296   auto input_w = qbmm_cnode->input(kIndex2);
297   MS_EXCEPTION_IF_NULL(input_w);
298   auto input_bias = qbmm_cnode->input(kIndex5);
299   MS_EXCEPTION_IF_NULL(input_bias);
300   auto input_scale = qbmm_cnode->input(kIndex3);
301   MS_EXCEPTION_IF_NULL(input_scale);
302   const std::set<TypeId> support_dtype = {kNumberTypeInt8};
303   if (!CheckSupportDataType(input_x, support_dtype) || !CheckMatMulDataFormat(qbmm_cnode)) {
304     return nullptr;
305   }
306 
307   size_t split_size_len = GetSplitSizeLen(split_cnode);
308   auto qbmm_split_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
309   std::vector<AnfNodePtr> qbmm_split_inputs = {input_x, input_w, qbmm_tuple, input_bias, input_scale};
310   auto qbmm_split_cnode = func_graph->NewCNode(qbmm_split_prim, qbmm_split_inputs);
311   MS_EXCEPTION_IF_NULL(qbmm_split_cnode);
312 
313   qbmm_split_cnode->set_fullname_with_scope(qbmm_cnode->fullname_with_scope() + "-SplitWithSize");
314   if (node->abstract() != nullptr) {
315     qbmm_split_cnode->set_abstract(split_cnode->abstract()->Clone());
316   }
317   visited_cnodes.insert(split_cnode);
318   MS_LOG(DEBUG) << "create QuantbatchmatmulSplit node success.";
319   return qbmm_split_cnode;
320 }
321 
/**
 * Rewires the outputs of a SplitWithSize node that is being fused together with a following SiLU.
 *
 * A new TupleGetItem(matmul_split_cnode, output_index) is created to stand in for the SiLU. Every
 * TupleGetItem user of @p split_cnode that reads an output other than @p output_index is
 * redirected to @p matmul_split_cnode through the graph manager.
 *
 * @param func_graph Graph owning the nodes; its manager is used to query users and rewire edges.
 * @param split_cnode The original SplitWithSize node.
 * @param matmul_split_cnode The fused replacement node.
 * @param silu_cnode The SiLU being fused; its abstract is copied onto the new TupleGetItem.
 * @param output_index Split output index consumed by the SiLU.
 * @return The new TupleGetItem node, or nullptr if the split has no TupleGetItem user reading a
 *         different output. In that case the graph is left unmodified.
 */
CNodePtr InferenceMatmulSplitFusion::CreateGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &split_cnode,
                                                       const CNodePtr &matmul_split_cnode, const CNodePtr &silu_cnode,
                                                       const size_t output_index) const {
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto iter = manager->node_users().find(split_cnode);
  if (iter == manager->node_users().end()) {
    MS_LOG(DEBUG) << "node has no output in manager";
    return nullptr;
  }

  // Collect first and rewire afterwards: SetEdge updates the manager's user lists, so they must
  // not be iterated while edges are changing.
  std::vector<CNodePtr> other_item_cnodes;
  for (const auto &output_info : iter->second) {
    if (common::AnfAlgo::GetCNodeName(output_info.first) != prim::kPrimTupleGetItem->name()) {
      continue;
    }
    auto item_cnode = utils::cast<CNodePtr>(output_info.first);
    if (common::AnfAlgo::GetTupleGetItemOutIndex(item_cnode) != output_index) {
      other_item_cnodes.push_back(item_cnode);
    }
  }
  MS_CHECK_TRUE_RET(!other_item_cnodes.empty(), nullptr);

  // Build the replacement before touching the graph, so a failure cannot leave it half rewired.
  auto index_value_node = NewValueNode(MakeValue(static_cast<int64_t>(output_index)));
  index_value_node->set_abstract(index_value_node->value()->ToAbstract());
  auto new_item_cnode =
    func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem->Clone()), matmul_split_cnode, index_value_node});
  MS_CHECK_TRUE_RET(new_item_cnode != nullptr, nullptr);
  if (silu_cnode->abstract() != nullptr) {
    new_item_cnode->set_abstract(silu_cnode->abstract()->Clone());
  }

  // Use the manager so node_users stays consistent; redirect every other-index reader, not just the first.
  for (const auto &item_cnode : other_item_cnodes) {
    manager->SetEdge(item_cnode, static_cast<int>(kRealInputNodeIndexInTupleGetItem), matmul_split_cnode);
  }
  MS_LOG(DEBUG) << "create new get_item_node success.";
  return new_item_cnode;
}
360 
CreateMatmulSplitSiluNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const361 CNodePtr InferenceMatmulSplitFusion::CreateMatmulSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
362                                                                const std::string &pattern_name) const {
363   MS_LOG(DEBUG) << "start create MatmulSplitSilu node";
364   MS_ASSERT(func_graph != nullptr && node != nullptr);
365   auto silu_cnode = node->cast<CNodePtr>();
366   MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr);
367   auto item_cnode = silu_cnode->input(kIndex1)->cast<CNodePtr>();
368   MS_CHECK_TRUE_RET(item_cnode != nullptr, nullptr);
369   auto split_cnode = item_cnode->input(kIndex1)->cast<CNodePtr>();
370   MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
371 
372   auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
373   MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
374   auto tuple = reshape_cnode->input(kIndex2);
375   MS_CHECK_TRUE_RET(tuple != nullptr, nullptr);
376   auto matmul_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
377   MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
378   MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), nullptr);
379 
380   auto x_node = matmul_cnode->input(kIndex1);
381   MS_EXCEPTION_IF_NULL(x_node);
382   auto weight_node = matmul_cnode->input(kIndex2);
383   MS_EXCEPTION_IF_NULL(weight_node);
384   const std::set<TypeId> support_dtype = {kNumberTypeFloat16, kNumberTypeBFloat16};
385   if (!CheckSupportDataType(x_node, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) {
386     return nullptr;
387   }
388   size_t split_size_len = GetSplitSizeLen(split_cnode);
389   if (split_size_len != kMatmulFfnSplitSizeLen) {
390     MS_LOG(DEBUG) << "MatmulSplitSilu only support ffn output";
391     return nullptr;
392   }
393   auto fusion_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
394   size_t output_index = common::AnfAlgo::GetTupleGetItemOutIndex(item_cnode);
395   fusion_prim->AddAttr("silu_position", MakeValue<int32_t>(output_index));
396   std::vector<AnfNodePtr> matmul_split_inputs = {x_node, weight_node, tuple};
397   auto matmul_split_cnode = func_graph->NewCNode(fusion_prim, matmul_split_inputs);
398   MS_EXCEPTION_IF_NULL(matmul_split_cnode);
399 
400   auto new_item_cnode = CreateGetItemNode(func_graph, split_cnode, matmul_split_cnode, silu_cnode, output_index);
401   MS_CHECK_TRUE_RET(new_item_cnode != nullptr, nullptr);
402   matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-SplitWithSizeSilu");
403   if (node->abstract() != nullptr) {
404     matmul_split_cnode->set_abstract(split_cnode->abstract()->Clone());
405   }
406   visited_cnodes.insert({silu_cnode, split_cnode});
407   MS_LOG(DEBUG) << "create MatmulSplitSilu node success.";
408   return new_item_cnode;
409 }
410 
CreateMatmulBiasAddSplitSiluNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const411 CNodePtr InferenceMatmulSplitFusion::CreateMatmulBiasAddSplitSiluNode(const FuncGraphPtr &func_graph,
412                                                                       const AnfNodePtr &node,
413                                                                       const std::string &pattern_name) const {
414   MS_LOG(DEBUG) << "start create MatmulBiasAddSplitSilu node";
415   MS_ASSERT(func_graph != nullptr && node != nullptr);
416   auto silu_cnode = node->cast<CNodePtr>();
417   MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr);
418   auto get_item_cnode = silu_cnode->input(kIndex1)->cast<CNodePtr>();
419   MS_CHECK_TRUE_RET(get_item_cnode != nullptr, nullptr);
420   auto split_cnode = get_item_cnode->input(kIndex1)->cast<CNodePtr>();
421   MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
422 
423   auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
424   MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
425   auto tuple_node = reshape_cnode->input(kIndex2);
426   MS_CHECK_TRUE_RET(tuple_node != nullptr, nullptr);
427   auto biasAdd_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
428   MS_CHECK_TRUE_RET(biasAdd_cnode != nullptr, nullptr);
429 
430   auto matmul_cnode = biasAdd_cnode->input(kIndex1)->cast<CNodePtr>();
431   MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
432   MS_CHECK_TRUE_RET(matmul_cnode->func_graph() == split_cnode->func_graph(), {});
433 
434   auto matmul_input = matmul_cnode->input(kIndex1);
435   MS_EXCEPTION_IF_NULL(matmul_input);
436   auto input_w = matmul_cnode->input(kIndex2);
437   MS_EXCEPTION_IF_NULL(input_w);
438   auto input_bias = biasAdd_cnode->input(kIndex2);
439   MS_EXCEPTION_IF_NULL(input_bias);
440   const std::set<TypeId> support_dtype = {kNumberTypeFloat16};
441   if (!CheckSupportDataType(matmul_input, support_dtype) || !CheckMatMulDataFormat(matmul_cnode)) {
442     return nullptr;
443   }
444   size_t split_len = GetSplitSizeLen(split_cnode);
445   if (split_len != kMatmulFfnSplitSizeLen) {
446     MS_LOG(DEBUG) << "MatmulBiasAddSplitSilu only support ffn output";
447     return nullptr;
448   }
449   auto matmul_split_prim = CreateMatmulSplitPrim(split_cnode, split_len, pattern_name);
450   size_t output_index = common::AnfAlgo::GetTupleGetItemOutIndex(get_item_cnode);
451   matmul_split_prim->AddAttr("silu_position", MakeValue<int32_t>(output_index));
452   matmul_split_prim->AddAttr("with_bias", MakeValue<bool>(true));
453   std::vector<AnfNodePtr> matmul_split_inputs = {matmul_input, input_w, tuple_node, input_bias};
454   auto matmul_split_cnode = func_graph->NewCNode(matmul_split_prim, matmul_split_inputs);
455   MS_EXCEPTION_IF_NULL(matmul_split_cnode);
456 
457   auto new_item_cnode = CreateGetItemNode(func_graph, split_cnode, matmul_split_cnode, silu_cnode, output_index);
458   MS_CHECK_TRUE_RET(new_item_cnode != nullptr, nullptr);
459   matmul_split_cnode->set_fullname_with_scope(matmul_cnode->fullname_with_scope() + "-BiasAddSplitWithSizeSilu");
460   if (node->abstract() != nullptr) {
461     matmul_split_cnode->set_abstract(split_cnode->abstract()->Clone());
462   }
463   visited_cnodes.insert({silu_cnode, split_cnode});
464   MS_LOG(DEBUG) << "create MatmulBiasAddSplitSilu node success.";
465   return new_item_cnode;
466 }
467 
CreateQuantbatchmatmulSplitSiluNode(const FuncGraphPtr & func_graph,const AnfNodePtr & node,const std::string & pattern_name) const468 CNodePtr InferenceMatmulSplitFusion::CreateQuantbatchmatmulSplitSiluNode(const FuncGraphPtr &func_graph,
469                                                                          const AnfNodePtr &node,
470                                                                          const std::string &pattern_name) const {
471   MS_LOG(DEBUG) << "start create QuantbatchmatmulSplitSilu node";
472   MS_ASSERT(func_graph != nullptr && node != nullptr);
473   auto silu_cnode = node->cast<CNodePtr>();
474   MS_CHECK_TRUE_RET(silu_cnode != nullptr, nullptr);
475   auto item_cnode = silu_cnode->input(kIndex1)->cast<CNodePtr>();
476   MS_CHECK_TRUE_RET(item_cnode != nullptr, nullptr);
477   auto split_cnode = item_cnode->input(kIndex1)->cast<CNodePtr>();
478   MS_CHECK_TRUE_RET(split_cnode != nullptr, nullptr);
479 
480   auto reshape_cnode = split_cnode->input(kIndex1)->cast<CNodePtr>();
481   MS_CHECK_TRUE_RET(reshape_cnode != nullptr, nullptr);
482   auto reshape_tuple = reshape_cnode->input(kIndex2);
483   MS_CHECK_TRUE_RET(reshape_tuple != nullptr, nullptr);
484   auto qbmm_cnode = reshape_cnode->input(kIndex1)->cast<CNodePtr>();
485   MS_CHECK_TRUE_RET(qbmm_cnode != nullptr, nullptr);
486   MS_CHECK_TRUE_RET(qbmm_cnode->func_graph() == split_cnode->func_graph(), nullptr);
487 
488   auto qbmm_x = qbmm_cnode->input(kIndex1);
489   MS_EXCEPTION_IF_NULL(qbmm_x);
490   auto qbmm_w = qbmm_cnode->input(kIndex2);
491   MS_EXCEPTION_IF_NULL(qbmm_w);
492   auto input_bias = qbmm_cnode->input(kIndex5);
493   MS_EXCEPTION_IF_NULL(input_bias);
494   auto input_scale = qbmm_cnode->input(kIndex3);
495   MS_EXCEPTION_IF_NULL(input_scale);
496   const std::set<TypeId> support_dtype = {kNumberTypeInt8};
497   if (!CheckSupportDataType(qbmm_x, support_dtype) || !CheckMatMulDataFormat(qbmm_cnode)) {
498     return nullptr;
499   }
500   size_t split_size_len = GetSplitSizeLen(split_cnode);
501   if (split_size_len != kMatmulFfnSplitSizeLen) {
502     MS_LOG(DEBUG) << "QuantbatchmatmulSplitSilu only support ffn output";
503     return nullptr;
504   }
505   auto qbmm_split_prim = CreateMatmulSplitPrim(split_cnode, split_size_len, pattern_name);
506   size_t output_index = common::AnfAlgo::GetTupleGetItemOutIndex(item_cnode);
507   qbmm_split_prim->AddAttr("silu_position", MakeValue<int32_t>(output_index));
508   std::vector<AnfNodePtr> qbmm_split_inputs = {qbmm_x, qbmm_w, reshape_tuple, input_bias, input_scale};
509   auto qbmm_split_cnode = func_graph->NewCNode(qbmm_split_prim, qbmm_split_inputs);
510   MS_EXCEPTION_IF_NULL(qbmm_split_cnode);
511 
512   auto new_item_cnode = CreateGetItemNode(func_graph, split_cnode, qbmm_split_cnode, silu_cnode, output_index);
513   MS_CHECK_TRUE_RET(new_item_cnode != nullptr, nullptr);
514   qbmm_split_cnode->set_fullname_with_scope(qbmm_cnode->fullname_with_scope() + "-SplitWithSizeSilu");
515   if (node->abstract() != nullptr) {
516     qbmm_split_cnode->set_abstract(split_cnode->abstract()->Clone());
517   }
518   visited_cnodes.insert({silu_cnode, split_cnode});
519   MS_LOG(DEBUG) << "create QuantbatchmatmulSplitSilu node success.";
520   return new_item_cnode;
521 }
522 
/**
 * Builds the fused node for @p pattern_name rooted at @p node and splices it into the graph.
 *
 * For the split-only patterns @p node is the SplitWithSize node. For the "*Silu" patterns it is
 * the SiLU node. On success every use of @p node is replaced by the returned node via the graph
 * manager.
 *
 * @param pattern_name Pattern selected by GetFusionPatternName().
 * @param func_graph Graph being optimized.
 * @param node Root node of the matched pattern.
 * @return The replacement node, or nullptr if the pattern is unknown or could not be built.
 */
AnfNodePtr InferenceMatmulSplitFusion::Process(const std::string &pattern_name, const FuncGraphPtr &func_graph,
                                               const AnfNodePtr &node) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(func_graph);
  auto graph_manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(graph_manager);

  // Root of the matched pattern: the SplitWithSize node or, for "*Silu" patterns, the SiLU node.
  auto root_cnode = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(root_cnode != nullptr, nullptr);

  CNodePtr fused_cnode = nullptr;
  if (pattern_name == kPatternNameMatMulSplit) {
    fused_cnode = CreateMatmulSplitNode(func_graph, node, pattern_name);
  } else if (pattern_name == kPatternNameMatMulBiasAddSplit) {
    fused_cnode = CreateMatmulBiasAddSplitNode(func_graph, node, pattern_name);
  } else if (pattern_name == kPatternNameQuantbatchmatmulSplit) {
    fused_cnode = CreateQuantbatchmatmulSplitNode(func_graph, node, pattern_name);
  } else if (pattern_name == kPatternNameMatMulSplitSilu) {
    fused_cnode = CreateMatmulSplitSiluNode(func_graph, node, pattern_name);
  } else if (pattern_name == kPatternNameMatMulBiasAddSplitSilu) {
    fused_cnode = CreateMatmulBiasAddSplitSiluNode(func_graph, node, pattern_name);
  } else if (pattern_name == kPatternNameQuantbatchmatmulSplitSilu) {
    fused_cnode = CreateQuantbatchmatmulSplitSiluNode(func_graph, node, pattern_name);
  }
  MS_CHECK_TRUE_RET(fused_cnode != nullptr, nullptr);

  (void)graph_manager->Replace(root_cnode, fused_cnode);
  MS_LOG(DEBUG) << "MatmulSplit replace success";
  return fused_cnode;
}
559 }  // namespace opt
560 }  // namespace mindspore
561