/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include <cmath>
#include <memory>
#include <vector>
#include <algorithm>
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "ops/fusion/mat_mul_fusion.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/quant_param_holder.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"
#include "ops/op_utils.h"
32
namespace mindspore::opt {
namespace {
// Minimum rank expected for the right-hand MatMul input (batch, row, col).
constexpr int64_t kFcRightInputDims = 3;
// Tolerance below which a float element is treated as zero (bias check).
constexpr float kFpPrecision = 1e-6;
GetInputAddr(const AnfNodePtr & node,size_t input_index)37 void *GetInputAddr(const AnfNodePtr &node, size_t input_index) {
38 MS_ASSERT(node != nullptr);
39 if (!node->isa<CNode>()) {
40 MS_LOG(ERROR) << "GetInputAddr not cnode";
41 return nullptr;
42 }
43 auto cnode = node->cast<CNodePtr>();
44 if (input_index >= cnode->size()) {
45 MS_LOG(ERROR) << "input index error";
46 return nullptr;
47 }
48 if (cnode->input(input_index)->isa<Parameter>()) {
49 auto param_input = cnode->input(input_index)->cast<ParameterPtr>();
50 MS_CHECK_TRUE_RET(param_input->default_param() != nullptr, nullptr);
51 auto tensor_info = param_input->default_param()->cast<tensor::TensorPtr>();
52 if (tensor_info == nullptr) {
53 MS_LOG(ERROR) << "param not tensor::Tensor";
54 return nullptr;
55 }
56 return tensor_info->data_c();
57 }
58 MS_LOG(ERROR) << "input not parameter";
59 return nullptr;
60 }
// Packs the weight tensors of all FullConnection inputs of `stack_node` into a
// single stacked parameter `rmatmul_input`, which becomes the constant
// right-hand input of the fused (batched) MatMul.
// The new leading dimension equals the number of stacked FullConnection nodes;
// slice [i] receives a byte copy of the (i+1)-th FullConnection's weight.
// NOTE(review): assumes every FullConnection weight has the same byte size as
// the first one (`tensor_size` below is reused for each copy) — the caller has
// already validated that all Stack inputs are FullConnection nodes.
STATUS GetRightMatmulInputParamter(const CNodePtr &stack_node, const ParameterPtr &rmatmul_input) {
  MS_ASSERT(stack_node != nullptr);
  MS_ASSERT(rmatmul_input != nullptr);
  // input(0) is the primitive, so the number of joined FullConnect nodes is size() - 1.
  auto joint_fullconnect_size = stack_node->size() - 1;
  auto fc = stack_node->input(1)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(fc != nullptr, lite::RET_NULL_PTR);
  auto fc_weight = fc->input(kInputIndexTwo)->cast<ParameterPtr>();
  MS_CHECK_TRUE_RET(fc_weight != nullptr, lite::RET_NULL_PTR);
  auto fc_weight_param = std::dynamic_pointer_cast<tensor::Tensor>(fc_weight->default_param());
  MS_CHECK_TRUE_RET(fc_weight_param != nullptr, lite::RET_NULL_PTR);
  // Per-slice byte size taken from the first weight tensor.
  auto tensor_size = fc_weight_param->Size();
  auto rmatmul_input_shape = fc_weight_param->shape();

  // Prepend the batch dimension: [N, <original weight shape>].
  rmatmul_input_shape.insert(rmatmul_input_shape.begin(), joint_fullconnect_size);
  std::vector<int64_t> shape_vector(rmatmul_input_shape.begin(), rmatmul_input_shape.end());
  auto tensor_info = lite::CreateTensorInfo(nullptr, 0, shape_vector, fc_weight_param->data_type());
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return RET_ERROR;
  }
  // Copy each FullConnection's weight into its slice of the stacked tensor.
  // memcpy_s's second argument is the remaining destination capacity, so the
  // copy can never overrun the allocated buffer.
  for (size_t i = 1; i < joint_fullconnect_size + 1; i++) {
    auto tensor_addr = GetInputAddr(stack_node->input(i), kInputIndexTwo);
    if (tensor_addr == nullptr) {
      MS_LOG(ERROR) << "input tensor addr nullptr";
      return RET_ERROR;
    }
    if (EOK != memcpy_s(static_cast<int8_t *>(tensor_info->data_c()) + (i - 1) * tensor_size,
                        tensor_info->Size() - (i - 1) * tensor_size, tensor_addr, tensor_size)) {
      MS_LOG(ERROR) << "memcpy_s data failed";
      return RET_ERROR;
    }
  }
  auto status = lite::InitParameterFromTensorInfo(rmatmul_input, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return RET_ERROR;
  }
  rmatmul_input->set_name(stack_node->fullname_with_scope());

  return RET_OK;
}
102
BuildMatMulPrim(const CNodePtr & stack_cnode)103 std::shared_ptr<ops::MatMulFusion> BuildMatMulPrim(const CNodePtr &stack_cnode) {
104 auto matmul_cvalue = std::make_shared<ops::MatMulFusion>();
105 if (matmul_cvalue == nullptr) {
106 MS_LOG(ERROR) << "new MatMul failed";
107 return nullptr;
108 }
109 auto matmul_prim_c = matmul_cvalue->GetPrim();
110 MS_CHECK_TRUE_RET(matmul_prim_c != nullptr, nullptr);
111
112 auto fullconnect_node = stack_cnode->input(1);
113 auto fullconnect_cnode = fullconnect_node->cast<CNodePtr>();
114 MS_CHECK_TRUE_RET(fullconnect_cnode != nullptr, nullptr);
115 auto fc_prim = GetValueNode<PrimitiveCPtr>(fullconnect_cnode->input(0));
116 MS_ASSERT(fc_prim != nullptr);
117
118 // quant param in QuantParamHolder
119 lite::QuantParamsVector rmatmul_quant_params;
120 lite::QuantParamsVector output_quant_params;
121 auto rmatmul_quant_params_valueptr = fc_prim->GetAttr("quant_params");
122 MS_CHECK_TRUE_RET(rmatmul_quant_params_valueptr != nullptr, nullptr);
123 auto rmatmul_quant_params_holder = rmatmul_quant_params_valueptr->cast<lite::QuantParamHolderPtr>();
124 if (rmatmul_quant_params_holder == nullptr) {
125 MS_LOG(ERROR) << "quant param is invalid.";
126 return nullptr;
127 }
128 rmatmul_quant_params = rmatmul_quant_params_holder->get_input_quant_params();
129 output_quant_params = rmatmul_quant_params_holder->get_output_quant_params();
130
131 // no bias quantParams
132 rmatmul_quant_params.pop_back();
133 auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(rmatmul_quant_params, output_quant_params);
134 MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
135 (void)matmul_prim_c->AddAttr("quant_params", quant_params_holder);
136
137 // compatible support: quant param in QuantizationParam, only copy output quant param
138 if (fc_prim->HasAttr(lite::quant::kQuantParam)) {
139 auto quantization_param_value = fc_prim->GetAttr(lite::quant::kQuantParam);
140 MS_CHECK_TRUE_MSG(quantization_param_value != nullptr, nullptr, "quantization_param_value is nullptr.");
141 (void)matmul_prim_c->AddAttr(lite::quant::kQuantParam, quantization_param_value);
142 }
143 return matmul_cvalue;
144 }
145
IsTensorZero(const tensor::TensorPtr & tensor)146 bool IsTensorZero(const tensor::TensorPtr &tensor) {
147 MS_ASSERT(tensor != nullptr);
148 if (tensor->data_type() != TypeId::kNumberTypeFloat32) {
149 return false;
150 }
151 auto data = reinterpret_cast<float *>(tensor->data_c());
152 for (size_t i = 0; i < tensor->DataSize(); i++) {
153 if (data[i] > kFpPrecision) {
154 return false;
155 }
156 }
157 return true;
158 }
159
IsFCNonBias(const CNodePtr & fc)160 bool IsFCNonBias(const CNodePtr &fc) {
161 MS_ASSERT(fc != nullptr);
162 if (fc->size() == kInputSizeThree) {
163 return true;
164 }
165 auto bias_input = fc->inputs().at(kInputSizeThree);
166 if (utils::isa<CNodePtr>(bias_input)) {
167 return false;
168 } else if (utils::isa<ParameterPtr>(bias_input)) {
169 auto bias_param = utils::cast<ParameterPtr>(bias_input);
170 if (!bias_param->has_default()) {
171 return false;
172 }
173 auto bias_default_param = bias_param->default_param();
174 if (bias_default_param == nullptr || !utils::isa<tensor::TensorPtr>(bias_default_param)) {
175 return false;
176 }
177 auto bias_tensor = utils::cast<tensor::TensorPtr>(bias_default_param);
178 if (!IsTensorZero(bias_tensor)) {
179 return false;
180 }
181 } else if (utils::isa<ValuePtr>(bias_input)) {
182 auto bias_value = utils::cast<ValuePtr>(bias_input);
183 if (!utils::isa<tensor::TensorPtr>(bias_value)) {
184 return false;
185 }
186 auto bias_tensor = utils::cast<tensor::TensorPtr>(bias_value);
187 if (!IsTensorZero(bias_tensor)) {
188 return false;
189 }
190 }
191 return true;
192 }
193
ConnectTransposeConcat(const AnfNodePtr & node)194 bool ConnectTransposeConcat(const AnfNodePtr &node) {
195 auto cnode = node->cast<CNodePtr>();
196 if (cnode == nullptr) {
197 MS_LOG(ERROR) << "cnode is null";
198 return false;
199 }
200 auto right_transpose_node = cnode->input(1);
201 MS_CHECK_TRUE_RET(right_transpose_node != nullptr, false);
202 auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
203 MS_CHECK_TRUE_RET(right_transpose_cnode != nullptr, false);
204 if (CheckPrimitiveType(right_transpose_cnode, prim::kPrimConcat)) {
205 return true;
206 }
207 auto front_node = right_transpose_cnode->input(1);
208 MS_CHECK_TRUE_RET(front_node != nullptr, false);
209 auto front_cnode = front_node->cast<CNodePtr>();
210 MS_CHECK_TRUE_RET(front_cnode != nullptr, false);
211 if (CheckPrimitiveType(right_transpose_cnode, prim::kPrimTranspose) &&
212 CheckPrimitiveType(front_cnode, prim::kPrimConcat)) {
213 return true;
214 }
215 return false;
216 }
217
ResetReshapeParameters(const AnfNodePtr & reshape_node)218 int ResetReshapeParameters(const AnfNodePtr &reshape_node) {
219 auto reshape_cnode = reshape_node->cast<CNodePtr>();
220 MS_ASSERT(reshape_cnode != nullptr);
221 auto reshape_shape_param = reshape_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
222 MS_ASSERT(reshape_shape_param != nullptr);
223 auto shape_tensor = std::dynamic_pointer_cast<tensor::Tensor>(reshape_shape_param->default_param());
224 auto rmatmul_input_shape = shape_tensor->shape();
225
226 std::vector<int64_t> shape(1, 0);
227 if (rmatmul_input_shape.size() <= 0) {
228 MS_LOG(ERROR) << "Create tensor info failed";
229 return RET_ERROR;
230 } else if (shape[0] < kFcRightInputDims) {
231 if (INT_ADD_OVERFLOW_THRESHOLD(rmatmul_input_shape[0], 1, INT64_MAX)) {
232 MS_LOG(ERROR) << "rmatmul_input_shape[0] overflow: " << rmatmul_input_shape[0];
233 return RET_ERROR;
234 }
235 shape[0] = rmatmul_input_shape[0] + 1;
236 }
237
238 auto tensor_info = std::make_shared<tensor::Tensor>(shape_tensor->data_type(), shape);
239 if (tensor_info == nullptr) {
240 MS_LOG(ERROR) << "Create tensor info failed";
241 return RET_ERROR;
242 }
243
244 int *tensor_data = reinterpret_cast<int *>(tensor_info->data_c());
245 tensor_data[0] = 1;
246 int *reshape_data = reinterpret_cast<int *>(shape_tensor->data_c());
247 for (int64_t i = 1; i < shape[0]; ++i) {
248 tensor_data[i] = reshape_data[i - 1];
249 }
250
251 auto ret = lite::InitParameterFromTensorInfo(reshape_shape_param, tensor_info);
252 if (ret != RET_OK) {
253 MS_LOG(ERROR) << "init parameter from tensor info failed";
254 return RET_ERROR;
255 }
256 return RET_OK;
257 }
258 } // namespace
259
DefinePattern() const260 const BaseRef BatchMatMulFusion::DefinePattern() const {
261 auto is_stack = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStack>);
262 MS_CHECK_TRUE_RET(is_stack != nullptr, {});
263 auto is_fullconnect1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
264 MS_CHECK_TRUE_RET(is_fullconnect1 != nullptr, {});
265 auto is_fullconnect2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
266 MS_CHECK_TRUE_RET(is_fullconnect2 != nullptr, {});
267 auto is_seq_var = std::make_shared<SeqVar>();
268 MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
269 return VectorRef({is_stack, is_fullconnect1, is_fullconnect2, is_seq_var});
270 }
271
CheckCnodeProper(const CNodePtr & stack_cnode,const CNodePtr & fullconnect_cnode,const CNodePtr & left_slice_cnode) const272 bool BatchMatMulFusion::CheckCnodeProper(const CNodePtr &stack_cnode, const CNodePtr &fullconnect_cnode,
273 const CNodePtr &left_slice_cnode) const {
274 if (IsMarkedTrainOp(stack_cnode)) {
275 return false;
276 }
277 // check stack node all inputs must fullconnect
278 for (size_t i = 1; i < stack_cnode->size(); i++) {
279 auto input_node = stack_cnode->input(i);
280 if (!CheckPrimitiveType(input_node, prim::kPrimFullConnection)) {
281 MS_LOG(WARNING) << "batchmatmulfusion stack node all inputs must fullconnect type";
282 return false;
283 }
284 }
285
286 if (IsMarkedTrainOp(fullconnect_cnode)) {
287 return false;
288 }
289 if (!IsFCNonBias(fullconnect_cnode)) {
290 return false;
291 }
292
293 if (IsMarkedTrainOp(left_slice_cnode)) {
294 return false;
295 }
296
297 if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimSliceFusion)) {
298 if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimReshape)) {
299 return false;
300 }
301 auto up_slice_cnode = left_slice_cnode->input(1)->cast<CNodePtr>();
302 if (IsMarkedTrainOp(up_slice_cnode)) {
303 return false;
304 }
305 if (up_slice_cnode == nullptr || !CheckPrimitiveType(up_slice_cnode, prim::kPrimSliceFusion)) {
306 return false;
307 }
308 }
309 return true;
310 }
311
// Replaces a Stack-of-FullConnections subgraph with one batched MatMulFusion:
//   Stack(FC(Slice(x), w1), FC(Slice(x), w2), ...)  ->  MatMul(x, W)
// Returns the new MatMul cnode on success, the original `node` when the
// constant right-hand weights cannot be merged (graph left untouched), or
// nullptr when the pattern does not qualify or construction fails.
const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                            const EquivPtr &) const {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  MS_CHECK_TRUE_RET(node != nullptr, nullptr);
  // The pattern guarantees `node` is a Stack whose input(1) is a FullConnection.
  auto stack_cnode = node->cast<CNodePtr>();
  auto fullconnect_node = stack_cnode->input(1);
  auto fullconnect_cnode = fullconnect_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(fullconnect_cnode != nullptr, nullptr);
  auto left_slice_node = fullconnect_cnode->input(1);
  auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(left_slice_cnode != nullptr, nullptr);
  if (!CheckCnodeProper(stack_cnode, fullconnect_cnode, left_slice_cnode)) {
    MS_LOG(WARNING) << stack_cnode->fullname_with_scope() << " can't fusion into matmul. Fusion failed";
    return nullptr;
  }
  // When the FC's left input is a Reshape, the real Slice sits one level up
  // (CheckCnodeProper has already verified that graph shape).
  if (CheckPrimitiveType(left_slice_cnode, prim::kPrimReshape)) {
    auto &left_reshape_cnode = left_slice_cnode;
    left_slice_cnode = left_reshape_cnode->input(1)->cast<CNodePtr>();
  }

  // slice +fullconnect ->batchmatmul
  auto left_matmul_input = left_slice_cnode->input(1);
  auto right_reshape_node = fullconnect_cnode->input(kInputIndexTwo);
  MS_ASSERT(right_reshape_node != nullptr);
  auto matmul_cvalue = BuildMatMulPrim(stack_cnode);
  MS_CHECK_TRUE_RET(matmul_cvalue != nullptr, nullptr);
  auto matmul_value_node = NewValueNode(matmul_cvalue->GetPrim());
  MS_CHECK_TRUE_RET(matmul_value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};

  // batchmatmul right node may be const
  bool right_transpose = false;
  if (right_reshape_node->isa<Parameter>()) {
    // Constant weights: merge all FC weight tensors into one stacked parameter.
    auto rmatmul_paramter = func_graph->add_parameter();
    MS_CHECK_TRUE_RET(rmatmul_paramter != nullptr, nullptr);
    if (GetRightMatmulInputParamter(stack_cnode, rmatmul_paramter) != RET_OK) {
      MS_LOG(ERROR) << "GetRightMatmulInputParamter failed";
      // Merging failed: return the original node so the graph stays intact.
      return node;
    }
    auto prim_matmul = ops::GetOperator<mindspore::ops::MatMulFusion>(matmul_value_node);
    MS_ASSERT(prim_matmul != nullptr);
    // FC weights are laid out transposed relative to MatMul's B input.
    prim_matmul->set_transpose_b(true);
    matmul_inputs.push_back(rmatmul_paramter);
  } else if (ConnectTransposeConcat(right_reshape_node)) {
    // Right side feeds through (Transpose ->) Concat: keep the Reshape as the
    // B input, but prepend a batch dim of 1 to its target shape and mark B
    // transposed on the final MatMul (done after the cnode is built).
    right_transpose = true;
    auto ret = ResetReshapeParameters(right_reshape_node);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "reset reshape parameters failed";
      return nullptr;
    }
    matmul_inputs.push_back(right_reshape_node);
  } else {
    // General case: walk Reshape -> Transpose -> Slice and take the Slice's
    // source tensor as the right MatMul input.
    auto right_reshape_cnode = right_reshape_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(right_reshape_cnode != nullptr, nullptr);
    if (IsMarkedTrainOp(right_reshape_cnode)) {
      return nullptr;
    }
    MS_ASSERT(right_reshape_cnode->size() > 1);
    auto right_transpose_node = right_reshape_cnode->input(1);
    MS_CHECK_TRUE_RET(right_transpose_node != nullptr, nullptr);
    auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(right_transpose_cnode != nullptr, nullptr);
    auto right_slice_node = right_transpose_cnode->input(1);
    MS_CHECK_TRUE_RET(right_slice_node != nullptr, nullptr);
    auto right_slice_cnode = right_slice_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(right_slice_cnode != nullptr, nullptr);
    auto right_matmul_input = right_slice_cnode->input(1);
    matmul_inputs.push_back(right_matmul_input);
  }
  auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
  MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
  // The fused node inherits the Stack's name and output abstract.
  matmul_cnode->set_fullname_with_scope(stack_cnode->fullname_with_scope());
  MS_CHECK_TRUE_RET(stack_cnode->abstract() != nullptr, nullptr);
  matmul_cnode->set_abstract(stack_cnode->abstract()->Clone());
  if (right_transpose) {
    auto matmul_primitive = ops::GetOperator<ops::MatMulFusion>(matmul_cnode->input(0));
    matmul_primitive->set_transpose_b(true);
  }
  MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
  return matmul_cnode;
}
393 } // namespace mindspore::opt
394