1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_INFERENCE_MATMUL_SPLIT_FUSION_H_ 18 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_INFERENCE_MATMUL_SPLIT_FUSION_H_ 19 20 #include <string> 21 #include <memory> 22 #include <map> 23 #include <set> 24 25 #include "plugin/device/ascend/optimizer/ir_fusion/inference_weight_preprocess_utils.h" 26 #include "include/backend/optimizer/pass.h" 27 #include "ir/func_graph.h" 28 #include "ir/anf.h" 29 #include "include/backend/optimizer/helper.h" 30 #include "include/backend/optimizer/optimizer.h" 31 #include "mindspore/core/ops/nn_ops.h" 32 #include "mindspore/core/ops/math_ops.h" 33 #include "mindspore/core/ops/sequence_ops.h" 34 #include "mindspore/core/ops/framework_ops.h" 35 36 namespace mindspore { 37 namespace opt { 38 constexpr auto kMatmulQkvSplitSizeLen = 3; 39 constexpr auto kMatmulFfnSplitSizeLen = 2; 40 41 class InferenceMatmulSplitFusion : public Pass { 42 public: InferenceMatmulSplitFusion()43 InferenceMatmulSplitFusion() : Pass("inference_matmul_split_fusion") {} 44 ~InferenceMatmulSplitFusion() override = default; 45 bool Run(const FuncGraphPtr &graph) override; 46 47 private: 48 std::string GetFusionPatternName(const CNodePtr &cnode) const; 49 std::string GetSplitFusionPatternName(const CNodePtr &cnode) const; 50 bool CheckMatMulDataFormat(const CNodePtr &matmul_cnode) const; 51 size_t GetSplitSizeLen(const CNodePtr &split_cnode) const; 52 PrimitivePtr CreateMatmulSplitPrim(const CNodePtr &split_cnode, size_t split_size_len, const std::string &) const; 53 CNodePtr CreateGetItemNode(const FuncGraphPtr &func_graph, const CNodePtr &split_cnode, 54 const CNodePtr &matmul_split_cnode, const CNodePtr &silu_cnode, 55 const size_t output_index) const; 56 CNodePtr CreateMatmulSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::string &) const; 57 CNodePtr CreateMatmulBiasAddSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, 58 const std::string &) const; 59 CNodePtr CreateQuantbatchmatmulSplitNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, 60 const std::string &) const; 61 CNodePtr CreateMatmulSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const std::string &) const; 62 CNodePtr CreateMatmulBiasAddSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, 63 const std::string &) const; 64 CNodePtr CreateQuantbatchmatmulSplitSiluNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, 65 const std::string &) const; 66 bool enable_fusion_silu = false; 67 mutable std::set<CNodePtr> visited_cnodes; 68 69 protected: 70 AnfNodePtr Process(const std::string &pattern_name, const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; 71 const std::string kPrimNameMatmulSplitOut2 = "MatmulSplitOut2"; 72 const std::string kPrimNameMatmulSplitOut3 = "MatmulSplitOut3"; 73 const std::string kPrimNameMatmulSplitSiluOut2 = "MatmulSplitSiluOut2"; 74 const std::string kPrimNameMatmulBiasSplitOut2 = "MatmulBiasSplitOut2"; 75 const std::string kPrimNameMatmulBiasSplitOut3 = "MatmulBiasSplitOut3"; 76 const std::string kPrimNameMatmulBiasSplitSiluOut2 = "MatmulBiasSplitSiluOut2"; 77 const std::string kPrimNameQuantbatchmatmulSplitOut2 = "QuantbatchmatmulSplitOut2"; 78 const std::string kPrimNameQuantbatchmatmulSplitOut3 = "QuantbatchmatmulSplitOut3"; 79 const std::string kPrimNameQuantbatchmatmulSplitSiluOut2 = "QuantbatchmatmulSplitSiluOut2"; 80 81 const std::string kPatternNameMatMulSplit = "MatmulSplit"; 82 const std::string kPatternNameMatMulSplitSilu = "MatmulSplitSilu"; 83 const std::string kPatternNameMatMulBiasAddSplit = "MatmulBiasAddSplit"; 84 const std::string kPatternNameMatMulBiasAddSplitSilu = "MatmulBiasAddSplitSilu"; 85 const std::string kPatternNameQuantbatchmatmulSplit = "QuantbatchmatmulSplit"; 86 const std::string kPatternNameQuantbatchmatmulSplitSilu = "QuantbatchmatmulSplitSilu"; 87 88 std::map<size_t, std::map<std::string, std::string>> PatternPrimMap = { 89 { 90 kMatmulQkvSplitSizeLen, 91 {{kPatternNameMatMulSplit, kPrimNameMatmulSplitOut3}, 92 {kPatternNameMatMulBiasAddSplit, kPrimNameMatmulBiasSplitOut3}, 93 {kPatternNameQuantbatchmatmulSplit, kPrimNameQuantbatchmatmulSplitOut3}}, 94 }, 95 96 {kMatmulFfnSplitSizeLen, 97 {{kPatternNameMatMulSplit, kPrimNameMatmulSplitOut2}, 98 {kPatternNameMatMulSplitSilu, kPrimNameMatmulSplitSiluOut2}, 99 {kPatternNameMatMulBiasAddSplit, kPrimNameMatmulBiasSplitOut2}, 100 {kPatternNameMatMulBiasAddSplitSilu, kPrimNameMatmulBiasSplitSiluOut2}, 101 {kPatternNameQuantbatchmatmulSplit, kPrimNameQuantbatchmatmulSplitOut2}, 102 {kPatternNameQuantbatchmatmulSplitSilu, kPrimNameQuantbatchmatmulSplitSiluOut2}}}}; 103 }; 104 } // namespace opt 105 } // namespace mindspore 106 107 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_INFERENCE_MATMUL_SPLIT_FUSION_H_ 108