/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/graph_kernel/core/parallel_matmul_concatenate.h"
#include "base/base.h"
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"

namespace mindspore::graphkernel {
namespace {
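// Read the 'transpose_a'/'transpose_b' attributes from a MatMul node.
// Falls back to (false, false) with a warning if either attribute is missing.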
MMAttr GetMatMulTransposeAttr(const CNodePtr &matmul) {
  auto mm_attrs = common::AnfAlgo::GetCNodePrimitive(matmul)->attrs();
  if (mm_attrs.count(kTransposeA) == 0 || mm_attrs.count(kTransposeB) == 0) {
    MS_LOG(WARNING) << "Can not find attr 'transpose_a' or 'transpose_b' in node " << matmul->fullname_with_scope();
    return std::make_pair(false, false);
  }
  auto trans_a = GetValue<bool>(mm_attrs[kTransposeA]);
  auto trans_b = GetValue<bool>(mm_attrs[kTransposeB]);
  return std::make_pair(trans_a, trans_b);
}

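// Build a new MatMul CNode over the concatenated inputs: copy the output format from
// the first input node, copy the transpose attributes from the original MatMul, and
// set the fused output shape on the new node.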
CNodePtr NewMatMulNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &matmul_inputs, const CNodePtr &orig_matmul,
                       ShapeVector new_out_shape) {
  auto matmul = func_graph->NewCNode(matmul_inputs);
  MS_EXCEPTION_IF_NULL(matmul);
  func_graph->AddNode(matmul);
  MS_EXCEPTION_IF_NULL(matmul_inputs[1]);
  auto orig_cnode = matmul_inputs[1]->cast<CNodePtr>();
  if (orig_cnode != nullptr && orig_cnode->HasAttr(kOutputsFormat)) {
    auto input_format = GetValue<std::vector<std::string>>(orig_cnode->GetAttr(kOutputsFormat))[0];
    std::vector<std::string> outputs_formats(AnfUtils::GetOutputTensorNum(matmul), input_format);
    matmul->AddAttr(kOutputsFormat, MakeValue(outputs_formats));
  }
  auto [trans_a, trans_b] = GetMatMulTransposeAttr(orig_matmul);
  matmul->AddAttr(kTransposeA, MakeValue(trans_a));
  matmul->AddAttr(kTransposeB, MakeValue(trans_b));
  std::vector<TypeId> dtypes = {common::AnfAlgo::GetOutputInferDataType(matmul_inputs[1], 0)};
  std::vector<ShapeVector> shapes = {new_out_shape};
  common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, shapes, matmul.get());
  matmul->set_kernel_info(std::make_shared<device::KernelInfo>());
  return matmul;
}

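// Derive (batch, M, N, K) from the MatMul's input shapes, taking the transpose
// attributes into account. Batch is read from rank-3 inputs with matching leading
// dimensions; otherwise it defaults to 1.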
BMNK GetBatchMNK(const CNodePtr &matmul) {
  int64_t b = 0;
  int64_t m = 0;
  int64_t n = 0;
  int64_t k = 0;
  auto shape_a = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex0);
  auto shape_b = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex1);
  auto [trans_a, trans_b] = GetMatMulTransposeAttr(matmul);
  if (shape_a.size() == kDim3 && shape_b.size() == kDim3 && shape_a[kIndex0] == shape_b[kIndex0]) {
    b = shape_a[kIndex0];
    (void)shape_a.erase(shape_a.begin());
    (void)shape_b.erase(shape_b.begin());
  } else {
    b = 1;
  }
  m = trans_a ? shape_a[kIndex1] : shape_a[kIndex0];
  k = trans_a ? shape_a[kIndex0] : shape_a[kIndex1];
  n = trans_b ? shape_b[kIndex0] : shape_b[kIndex1];
  return std::tuple(b, m, n, k);
}
}  // namespace

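// Plan the concatenation for a group of parallel MatMul branches: choose the input
// axis to concatenate along, the output axis to split, and the fused output shape,
// depending on whether the shared input feeds operand A or operand B.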
ConcatenatePlan ParallelMatMulConcatenater::Analyse(const Group &branches) const {
  ConcatenatePlan target_op_res;
  Branch b0 = branches[kIndex0];
  AnfNodePtr shared_input = b0.GetRootData();
  target_op_res.in_shape = Callback::Instance()->GetOutputInferShape(shared_input, kIndex0);
  auto matmul = b0.GetTargetOp()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(matmul);
  bool is_a_shared = false;
  for (size_t i = 1; i < matmul->size(); ++i) {
    auto in = matmul->input(i);
    if (in == shared_input) {
      is_a_shared = i == kIndex1;
      break;
    }
  }

  auto [trans_a, trans_b] = GetMatMulTransposeAttr(matmul);
  int64_t b = 0;
  int64_t m = 0;
  int64_t n = 0;
  int64_t k = 0;
  std::tie(b, m, n, k) = GetBatchMNK(matmul);
  if (is_a_shared) {
    auto shape_b = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex1);
    size_t rank_b = shape_b.size();
    auto n_idx = trans_b ? rank_b - kIndex2 : rank_b - kIndex1;
    target_op_res.concat_in_idx = SizeToInt(n_idx);
    target_op_res.split_out_idx = SizeToInt(rank_b - kIndex1);
    int64_t new_n = n * SizeToLong(branches.size());
    if (rank_b == kDim3) {
      target_op_res.out_shape = ShapeVector({b, m, new_n});
    } else {
      target_op_res.out_shape = ShapeVector({m, new_n});
    }
  } else {
    auto shape_a = common::AnfAlgo::GetPrevNodeOutputInferShape(matmul, kIndex0);
    size_t rank_a = shape_a.size();
    auto m_idx = trans_a ? rank_a - kIndex1 : rank_a - kIndex2;
    target_op_res.concat_in_idx = SizeToInt(m_idx);
    target_op_res.split_out_idx = SizeToInt(rank_a - kIndex2);
    auto new_m = m * SizeToLong(branches.size());
    if (rank_a == kDim3) {
      target_op_res.out_shape = ShapeVector({b, new_m, n});
    } else {
      target_op_res.out_shape = ShapeVector({new_m, n});
    }
  }
  return target_op_res;
}

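// Two MatMul ops can be combined only if they share the same transpose settings.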
bool ParallelMatMulConcatenater::CanOpsBeCombined(const AnfNodePtr a, const AnfNodePtr b) {
  auto matmul1 = a->cast<CNodePtr>();
  auto matmul2 = b->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(matmul1);
  MS_EXCEPTION_IF_NULL(matmul2);
  auto [trans_a1, trans_b1] = GetMatMulTransposeAttr(matmul1);
  auto [trans_a2, trans_b2] = GetMatMulTransposeAttr(matmul2);
  return trans_a1 == trans_a2 && trans_b1 == trans_b2;
}

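// A node can participate in the combination only if it is a CNode whose primitive
// is not in the unsupported-op set.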
bool ParallelMatMulConcatenater::IsSupportedOp(const AnfNodePtr n) {
  if (n == nullptr || n->cast<CNodePtr>() == nullptr) {
    return false;
  }
  auto prim = GetCNodePrimitive(n);
  if (prim == nullptr || unsupported_ops_.count(prim->name()) > 0) {
    return false;
  }
  return true;
}

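// Create the fused MatMul for a group of parallel branches: derive the concat plan,
// reload the branch inputs, and build a single MatMul with the fused output shape.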
AnfNodePtr ParallelMatMulConcatenater::MakeCombinedOp(const Group &branches) {
  Branch b1 = branches[0];
  AnfNodePtr shared_input = b1.GetRootData();
  auto matmul_op = b1.GetTargetOp()->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(matmul_op);
  auto plan = Analyse(branches);
  plans_.push_back(plan);
  auto overall_inputs = ReloadInputs(branches, b1.target_op_pos, shared_input);
  auto matmul = NewMatMulNode(main_graph_, overall_inputs, matmul_op, plan.out_shape);
  MS_EXCEPTION_IF_CHECK_FAIL(AutoUpdateInfo(matmul), "AutoUpdateInfo fail");
  return matmul;
}

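// No argument compatibility check is required beyond the transpose attributes.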
bool ParallelMatMulConcatenater::IsArgCompatible(const AnfNodePtr a, const AnfNodePtr b) { return true; }

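// Entry point: combine parallel MatMul branches rooted at `root`. Only the NCHW
// layout is handled; other layouts return the root unchanged with a warning.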
AnfNodePtr ConcatParallelMatMul(AnfNodePtr root, uint64_t min_num_branches, const std::string &layout,
                                const FuncGraphPtr &func_graph) {
  if (layout == kOpFormat_NCHW) {
    auto res = ParallelMatMulConcatenater(min_num_branches, layout).Combine(root, func_graph);
    return res;
  }
  MS_LOG(WARNING) << "Unsupported layout for parallel MatMul combination: " << layout;
  return root;
}
}  // namespace mindspore::graphkernel