/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include <memory>
#include <vector>
#include <algorithm>
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "ops/fusion/mat_mul_fusion.h"
#include "tools/common/tensor_util.h"
#include "tools/converter/quantizer/quant_param_holder.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "securec/include/securec.h"
#include "nnacl/op_base.h"
#include "ops/op_utils.h"

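// BatchMatMulFusion folds a Stack of FullConnection nodes, each fed by a Slice
// of the same left-hand tensor, into a single MatMulFusion node, so the
// per-slice fully-connected computations run as one batched matmul.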
namespace mindspore::opt {
namespace {
constexpr int64_t kFcRightInputDims = 3;
constexpr float kFpPrecision = 1e-6;
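// Returns the raw data pointer of the tensor behind `node`'s input at
// `input_index`; the input must be a Parameter holding a tensor::Tensor.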
void *GetInputAddr(const AnfNodePtr &node, size_t input_index) {
  MS_ASSERT(node != nullptr);
  if (!node->isa<CNode>()) {
    MS_LOG(ERROR) << "GetInputAddr: node is not a CNode";
    return nullptr;
  }
  auto cnode = node->cast<CNodePtr>();
  if (input_index >= cnode->size()) {
    MS_LOG(ERROR) << "input index out of range";
    return nullptr;
  }
  if (cnode->input(input_index)->isa<Parameter>()) {
    auto param_input = cnode->input(input_index)->cast<ParameterPtr>();
    MS_CHECK_TRUE_RET(param_input->default_param() != nullptr, nullptr);
    auto tensor_info = param_input->default_param()->cast<tensor::TensorPtr>();
    if (tensor_info == nullptr) {
      MS_LOG(ERROR) << "default param is not a tensor::Tensor";
      return nullptr;
    }
    return tensor_info->data_c();
  }
  MS_LOG(ERROR) << "input is not a Parameter";
  return nullptr;
}
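// Stacks the weight tensors of all FullConnection inputs of `stack_node` into
// one tensor with a new leading dimension and stores it in `rmatmul_input`,
// which becomes the right-hand (constant) operand of the fused matmul.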
STATUS GetRightMatmulInputParameter(const CNodePtr &stack_node, const ParameterPtr &rmatmul_input) {
  MS_ASSERT(stack_node != nullptr);
  MS_ASSERT(rmatmul_input != nullptr);
  auto joint_fullconnect_size = stack_node->size() - 1;
  auto fc = stack_node->input(1)->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(fc != nullptr, lite::RET_NULL_PTR);
  auto fc_weight = fc->input(kInputIndexTwo)->cast<ParameterPtr>();
  MS_CHECK_TRUE_RET(fc_weight != nullptr, lite::RET_NULL_PTR);
  auto fc_weight_param = std::dynamic_pointer_cast<tensor::Tensor>(fc_weight->default_param());
  MS_CHECK_TRUE_RET(fc_weight_param != nullptr, lite::RET_NULL_PTR);
  auto tensor_size = fc_weight_param->Size();
  auto rmatmul_input_shape = fc_weight_param->shape();

  rmatmul_input_shape.insert(rmatmul_input_shape.begin(), joint_fullconnect_size);
  std::vector<int64_t> shape_vector(rmatmul_input_shape.begin(), rmatmul_input_shape.end());
  auto tensor_info = lite::CreateTensorInfo(nullptr, 0, shape_vector, fc_weight_param->data_type());
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return RET_ERROR;
  }
  for (size_t i = 1; i < joint_fullconnect_size + 1; i++) {
    auto tensor_addr = GetInputAddr(stack_node->input(i), kInputIndexTwo);
    if (tensor_addr == nullptr) {
      MS_LOG(ERROR) << "input tensor addr nullptr";
      return RET_ERROR;
    }
    if (EOK != memcpy_s(static_cast<int8_t *>(tensor_info->data_c()) + (i - 1) * tensor_size,
                        tensor_info->Size() - (i - 1) * tensor_size, tensor_addr, tensor_size)) {
      MS_LOG(ERROR) << "memcpy_s data failed";
      return RET_ERROR;
    }
  }
  auto status = lite::InitParameterFromTensorInfo(rmatmul_input, tensor_info);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return RET_ERROR;
  }
  rmatmul_input->set_name(stack_node->fullname_with_scope());

  return RET_OK;
}

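// Creates the MatMulFusion primitive for the fused node, copying the input and
// output quant params from the first FullConnection (minus the trailing bias
// entry, since the fused matmul has no bias input) and forwarding the
// QuantizationParam attribute when present.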
std::shared_ptr<ops::MatMulFusion> BuildMatMulPrim(const CNodePtr &stack_cnode) {
  auto matmul_cvalue = std::make_shared<ops::MatMulFusion>();
  if (matmul_cvalue == nullptr) {
    MS_LOG(ERROR) << "new MatMul failed";
    return nullptr;
  }
  auto matmul_prim_c = matmul_cvalue->GetPrim();
  MS_CHECK_TRUE_RET(matmul_prim_c != nullptr, nullptr);

  auto fullconnect_node = stack_cnode->input(1);
  auto fullconnect_cnode = fullconnect_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(fullconnect_cnode != nullptr, nullptr);
  auto fc_prim = GetValueNode<PrimitiveCPtr>(fullconnect_cnode->input(0));
  MS_ASSERT(fc_prim != nullptr);

  // quant param in QuantParamHolder
  lite::QuantParamsVector rmatmul_quant_params;
  lite::QuantParamsVector output_quant_params;
  auto rmatmul_quant_params_valueptr = fc_prim->GetAttr("quant_params");
  MS_CHECK_TRUE_RET(rmatmul_quant_params_valueptr != nullptr, nullptr);
  auto rmatmul_quant_params_holder = rmatmul_quant_params_valueptr->cast<lite::QuantParamHolderPtr>();
  if (rmatmul_quant_params_holder == nullptr) {
    MS_LOG(ERROR) << "quant param is invalid.";
    return nullptr;
  }
  rmatmul_quant_params = rmatmul_quant_params_holder->get_input_quant_params();
  output_quant_params = rmatmul_quant_params_holder->get_output_quant_params();

  // no bias quant params
  rmatmul_quant_params.pop_back();
  auto quant_params_holder = std::make_shared<lite::QuantParamHolder>(rmatmul_quant_params, output_quant_params);
  MS_CHECK_TRUE_RET(quant_params_holder != nullptr, nullptr);
  (void)matmul_prim_c->AddAttr("quant_params", quant_params_holder);

  // compatible support: quant param in QuantizationParam, only copy output quant param
  if (fc_prim->HasAttr(lite::quant::kQuantParam)) {
    auto quantization_param_value = fc_prim->GetAttr(lite::quant::kQuantParam);
    MS_CHECK_TRUE_MSG(quantization_param_value != nullptr, nullptr, "quantization_param_value is nullptr.");
    (void)matmul_prim_c->AddAttr(lite::quant::kQuantParam, quantization_param_value);
  }
  return matmul_cvalue;
}

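// Treats a float32 tensor as zero when every element lies within
// [-kFpPrecision, kFpPrecision]; other data types never count as zero.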
bool IsTensorZero(const tensor::TensorPtr &tensor) {
  MS_ASSERT(tensor != nullptr);
  if (tensor->data_type() != TypeId::kNumberTypeFloat32) {
    return false;
  }
  auto data = reinterpret_cast<float *>(tensor->data_c());
  for (size_t i = 0; i < tensor->DataSize(); i++) {
    if (data[i] > kFpPrecision || data[i] < -kFpPrecision) {
      return false;
    }
  }
  return true;
}

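// A FullConnection node qualifies for fusion when it either has no bias input
// or its bias is a constant all-zero tensor.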
bool IsFCNonBias(const CNodePtr &fc) {
  MS_ASSERT(fc != nullptr);
  if (fc->size() == kInputSizeThree) {
    return true;
  }
  auto bias_input = fc->inputs().at(kInputSizeThree);
  if (utils::isa<CNodePtr>(bias_input)) {
    return false;
  } else if (utils::isa<ParameterPtr>(bias_input)) {
    auto bias_param = utils::cast<ParameterPtr>(bias_input);
    if (!bias_param->has_default()) {
      return false;
    }
    auto bias_default_param = bias_param->default_param();
    if (bias_default_param == nullptr || !utils::isa<tensor::TensorPtr>(bias_default_param)) {
      return false;
    }
    auto bias_tensor = utils::cast<tensor::TensorPtr>(bias_default_param);
    if (!IsTensorZero(bias_tensor)) {
      return false;
    }
  } else if (utils::isa<ValuePtr>(bias_input)) {
    auto bias_value = utils::cast<ValuePtr>(bias_input);
    if (!utils::isa<tensor::TensorPtr>(bias_value)) {
      return false;
    }
    auto bias_tensor = utils::cast<tensor::TensorPtr>(bias_value);
    if (!IsTensorZero(bias_tensor)) {
      return false;
    }
  }
  return true;
}

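// Checks whether `node`'s first input is a Concat, or a Transpose whose input
// is a Concat; such a right-hand producer requires transpose_b on the fused matmul.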
bool ConnectTransposeConcat(const AnfNodePtr &node) {
  auto cnode = node->cast<CNodePtr>();
  if (cnode == nullptr) {
    MS_LOG(ERROR) << "cnode is null";
    return false;
  }
  auto right_transpose_node = cnode->input(1);
  MS_CHECK_TRUE_RET(right_transpose_node != nullptr, false);
  auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(right_transpose_cnode != nullptr, false);
  if (CheckPrimitiveType(right_transpose_cnode, prim::kPrimConcat)) {
    return true;
  }
  auto front_node = right_transpose_cnode->input(1);
  MS_CHECK_TRUE_RET(front_node != nullptr, false);
  auto front_cnode = front_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(front_cnode != nullptr, false);
  if (CheckPrimitiveType(right_transpose_cnode, prim::kPrimTranspose) &&
      CheckPrimitiveType(front_cnode, prim::kPrimConcat)) {
    return true;
  }
  return false;
}

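// Prepends a leading 1 to the shape held by the Reshape's shape parameter, so
// the reshaped right operand gains the batch dimension the fused matmul expects.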
int ResetReshapeParameters(const AnfNodePtr &reshape_node) {
  auto reshape_cnode = reshape_node->cast<CNodePtr>();
  MS_ASSERT(reshape_cnode != nullptr);
  auto reshape_shape_param = reshape_cnode->input(kInputIndexTwo)->cast<ParameterPtr>();
  MS_ASSERT(reshape_shape_param != nullptr);
  auto shape_tensor = std::dynamic_pointer_cast<tensor::Tensor>(reshape_shape_param->default_param());
  MS_CHECK_TRUE_RET(shape_tensor != nullptr, RET_ERROR);
  auto rmatmul_input_shape = shape_tensor->shape();

  std::vector<int64_t> shape(1, 0);
  if (rmatmul_input_shape.empty()) {
    MS_LOG(ERROR) << "rmatmul_input_shape is empty";
    return RET_ERROR;
  } else if (shape[0] < kFcRightInputDims) {
    if (INT_ADD_OVERFLOW_THRESHOLD(rmatmul_input_shape[0], 1, INT64_MAX)) {
      MS_LOG(ERROR) << "rmatmul_input_shape[0] overflow: " << rmatmul_input_shape[0];
      return RET_ERROR;
    }
    shape[0] = rmatmul_input_shape[0] + 1;
  }

  auto tensor_info = std::make_shared<tensor::Tensor>(shape_tensor->data_type(), shape);
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << "Create tensor info failed";
    return RET_ERROR;
  }

  int *tensor_data = reinterpret_cast<int *>(tensor_info->data_c());
  tensor_data[0] = 1;
  int *reshape_data = reinterpret_cast<int *>(shape_tensor->data_c());
  for (int64_t i = 1; i < shape[0]; ++i) {
    tensor_data[i] = reshape_data[i - 1];
  }

  auto ret = lite::InitParameterFromTensorInfo(reshape_shape_param, tensor_info);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "init parameter from tensor info failed";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace

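// Matches a Stack node whose first two inputs are FullConnection nodes,
// followed by any number of further inputs.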
const BaseRef BatchMatMulFusion::DefinePattern() const {
  auto is_stack = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimStack>);
  MS_CHECK_TRUE_RET(is_stack != nullptr, {});
  auto is_fullconnect1 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
  MS_CHECK_TRUE_RET(is_fullconnect1 != nullptr, {});
  auto is_fullconnect2 = std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimFullConnection>);
  MS_CHECK_TRUE_RET(is_fullconnect2 != nullptr, {});
  auto is_seq_var = std::make_shared<SeqVar>();
  MS_CHECK_TRUE_RET(is_seq_var != nullptr, {});
  return VectorRef({is_stack, is_fullconnect1, is_fullconnect2, is_seq_var});
}

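// Verifies the matched subgraph is fusable: none of the nodes is a marked
// train op, every Stack input is a FullConnection, the first FullConnection
// carries no effective bias, and the left input chain is a SliceFusion
// (possibly behind a Reshape).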
bool BatchMatMulFusion::CheckCnodeProper(const CNodePtr &stack_cnode, const CNodePtr &fullconnect_cnode,
                                         const CNodePtr &left_slice_cnode) const {
  if (IsMarkedTrainOp(stack_cnode)) {
    return false;
  }
  // all inputs of the stack node must be FullConnection nodes
  for (size_t i = 1; i < stack_cnode->size(); i++) {
    auto input_node = stack_cnode->input(i);
    if (!CheckPrimitiveType(input_node, prim::kPrimFullConnection)) {
      MS_LOG(WARNING) << "BatchMatMulFusion: all inputs of the stack node must be FullConnection nodes";
      return false;
    }
  }

  if (IsMarkedTrainOp(fullconnect_cnode)) {
    return false;
  }
  if (!IsFCNonBias(fullconnect_cnode)) {
    return false;
  }

  if (IsMarkedTrainOp(left_slice_cnode)) {
    return false;
  }

  if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimSliceFusion)) {
    if (!CheckPrimitiveType(left_slice_cnode, prim::kPrimReshape)) {
      return false;
    }
    auto up_slice_cnode = left_slice_cnode->input(1)->cast<CNodePtr>();
    if (up_slice_cnode == nullptr || !CheckPrimitiveType(up_slice_cnode, prim::kPrimSliceFusion)) {
      return false;
    }
    if (IsMarkedTrainOp(up_slice_cnode)) {
      return false;
    }
  }
  return true;
}

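// Rewrites Stack(FullConnection(Slice(x), w_i), ...) into a single
// MatMulFusion: the left operand is the unsliced tensor; the right operand is
// either the stacked constant weights (with transpose_b set) or the dynamic
// right-hand producer found behind the Reshape/Transpose chain.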
const AnfNodePtr BatchMatMulFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                            const EquivPtr &) const {
  MS_CHECK_TRUE_RET(func_graph != nullptr, nullptr);
  MS_CHECK_TRUE_RET(node != nullptr, nullptr);
  auto stack_cnode = node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(stack_cnode != nullptr, nullptr);
  auto fullconnect_node = stack_cnode->input(1);
  auto fullconnect_cnode = fullconnect_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(fullconnect_cnode != nullptr, nullptr);
  auto left_slice_node = fullconnect_cnode->input(1);
  auto left_slice_cnode = left_slice_node->cast<CNodePtr>();
  MS_CHECK_TRUE_RET(left_slice_cnode != nullptr, nullptr);
  if (!CheckCnodeProper(stack_cnode, fullconnect_cnode, left_slice_cnode)) {
    MS_LOG(WARNING) << stack_cnode->fullname_with_scope() << " cannot be fused into a MatMul. Fusion failed";
    return nullptr;
  }
  if (CheckPrimitiveType(left_slice_cnode, prim::kPrimReshape)) {
    auto &left_reshape_cnode = left_slice_cnode;
    left_slice_cnode = left_reshape_cnode->input(1)->cast<CNodePtr>();
  }

  // slice + fullconnect -> batchmatmul
  auto left_matmul_input = left_slice_cnode->input(1);
  auto right_reshape_node = fullconnect_cnode->input(kInputIndexTwo);
  MS_ASSERT(right_reshape_node != nullptr);
  auto matmul_cvalue = BuildMatMulPrim(stack_cnode);
  MS_CHECK_TRUE_RET(matmul_cvalue != nullptr, nullptr);
  auto matmul_value_node = NewValueNode(matmul_cvalue->GetPrim());
  MS_CHECK_TRUE_RET(matmul_value_node != nullptr, nullptr);
  std::vector<AnfNodePtr> matmul_inputs = {matmul_value_node, left_matmul_input};

  // the right operand of the batchmatmul may be const
  bool right_transpose = false;
  if (right_reshape_node->isa<Parameter>()) {
    auto rmatmul_parameter = func_graph->add_parameter();
    MS_CHECK_TRUE_RET(rmatmul_parameter != nullptr, nullptr);
    if (GetRightMatmulInputParameter(stack_cnode, rmatmul_parameter) != RET_OK) {
      MS_LOG(ERROR) << "GetRightMatmulInputParameter failed";
      return node;
    }
    auto prim_matmul = ops::GetOperator<mindspore::ops::MatMulFusion>(matmul_value_node);
    MS_ASSERT(prim_matmul != nullptr);
    prim_matmul->set_transpose_b(true);
    matmul_inputs.push_back(rmatmul_parameter);
  } else if (ConnectTransposeConcat(right_reshape_node)) {
    right_transpose = true;
    auto ret = ResetReshapeParameters(right_reshape_node);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "reset reshape parameters failed";
      return nullptr;
    }
    matmul_inputs.push_back(right_reshape_node);
  } else {
    auto right_reshape_cnode = right_reshape_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(right_reshape_cnode != nullptr, nullptr);
    if (IsMarkedTrainOp(right_reshape_cnode)) {
      return nullptr;
    }
    MS_ASSERT(right_reshape_cnode->size() > 1);
    auto right_transpose_node = right_reshape_cnode->input(1);
    MS_CHECK_TRUE_RET(right_transpose_node != nullptr, nullptr);
    auto right_transpose_cnode = right_transpose_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(right_transpose_cnode != nullptr, nullptr);
    auto right_slice_node = right_transpose_cnode->input(1);
    MS_CHECK_TRUE_RET(right_slice_node != nullptr, nullptr);
    auto right_slice_cnode = right_slice_node->cast<CNodePtr>();
    MS_CHECK_TRUE_RET(right_slice_cnode != nullptr, nullptr);
    auto right_matmul_input = right_slice_cnode->input(1);
    matmul_inputs.push_back(right_matmul_input);
  }
  auto matmul_cnode = func_graph->NewCNode(matmul_inputs);
  MS_CHECK_TRUE_RET(matmul_cnode != nullptr, nullptr);
  matmul_cnode->set_fullname_with_scope(stack_cnode->fullname_with_scope());
  MS_CHECK_TRUE_RET(stack_cnode->abstract() != nullptr, nullptr);
  matmul_cnode->set_abstract(stack_cnode->abstract()->Clone());
  if (right_transpose) {
    auto matmul_primitive = ops::GetOperator<ops::MatMulFusion>(matmul_cnode->input(0));
    MS_CHECK_TRUE_RET(matmul_primitive != nullptr, nullptr);
    matmul_primitive->set_transpose_b(true);
  }
  MS_LOG(INFO) << "stack node:" << stack_cnode->fullname_with_scope() << " batchmatmul fusion success";
  return matmul_cnode;
}
}  // namespace mindspore::opt