/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API

#include "tools/converter/quantizer/weight_quantizer.h"
#include <list>
#include <string>
#include <utility>
#include <set>
#include "mindspore/core/ops/conv_pool_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/log_util.h"
#include "tools/converter/quantizer/fse_encoder.h"
#include "tools/converter/quantizer/tensor_compressor.h"
#include "tools/converter/quantizer/cluster_quantization.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantization.h"
#include "tools/converter/quantizer/fixed_bit_weight_quantization.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
#include "tools/common/node_util.h"
#include "src/common/quant_utils.h"
#include "tools/converter/quantizer/gptq_quantizer.h"

namespace mindspore::lite::quant {
namespace {
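// Returns the parameter's default tensor, converting fp16 data to a newly
// allocated fp32 tensor (and resetting the parameter's default value and
// abstract) when needed. Returns nullptr if the parameter has no default
// value or the default cannot be cast to tensor::Tensor.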
tensor::TensorPtr ConvertParameterFp16TensorToFp32(const ParameterPtr &parameter) {
  if (!parameter->has_default()) {
    MS_LOG(WARNING) << parameter->fullname_with_scope() << " has no default param.";
    return nullptr;
  }
  auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
  if (tensor_info == nullptr) {
    MS_LOG(WARNING) << "default_param can not cast to tensor::Tensor";
    return nullptr;
  }
  if (tensor_info->data_type() == kNumberTypeFloat16) {
    MS_LOG(INFO) << "convert " << parameter->fullname_with_scope() << " from fp16 to fp32.";
    auto data = static_cast<float16 *>(tensor_info->data_c());
    std::vector<float> fp32_data(tensor_info->DataSize());
    for (size_t j = 0; j < tensor_info->DataSize(); j++) {
      fp32_data[j] = mindspore::Float16::ToFloat32(data[j]);
    }
    mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
      kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), fp32_data.size() * sizeof(float));
    parameter->set_default_param(tensor_ptr);
    parameter->set_abstract(tensor_ptr->ToAbstract());
    return tensor_ptr;
  }
  return tensor_info;
}
}  // namespace
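
// Applies weight quantization to every CNode of the graph in topological
// order; returns RET_ERROR as soon as any node fails.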
int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
                                 const std::set<PrimitivePtr> &support_weight_quant_types,
                                 const std::set<PrimitivePtr> &per_layer_types,
                                 const std::set<PrimitivePtr> &symmetric_types, bool compression) {
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto ret =
      WeightQuantPerCNode(func_graph, cnode, support_weight_quant_types, per_layer_types, symmetric_types, compression);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " failed to execute weight quantization.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

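// Quantizes the weights of a single CNode: skips unsupported or explicitly
// excluded nodes, selects the weight input indices (optimizer ops such as
// Adam/SGD/ApplyMomentum quantize only their weight inputs), then dispatches
// to linear quantization or k-means clustering.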
int WeightQuantizer::WeightQuantPerCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                         const std::set<PrimitivePtr> &support_weight_quant_types,
                                         const std::set<PrimitivePtr> &per_layer_types,
                                         const std::set<PrimitivePtr> &symmetric_types, bool compression) {
  auto primitive = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
    return RET_OK;
  }
  auto op_name = cnode->fullname_with_scope();
  if (skip_quant_node_.find(op_name) != skip_quant_node_.end()) {
    MS_LOG(INFO) << op_name << " is set to skip quantization.";
    return RET_OK;
  }
  if (!CheckNodeInSet(cnode, support_weight_quant_types)) {
    MS_LOG(INFO) << cnode->fullname_with_scope() << " of type: " << primitive->name() << " doesn't need weight quant.";
    return RET_OK;
  }

  // Ascend ON_THE_FLY quant only supports Gather when it is followed by BatchMatMul, MatMul or FFN.
  if (ascend_backend_ && dequant_strategy_ == ON_THE_FLY && opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
    auto support_gather_followed_primitive_types = {prim::kPrimBatchMatMul, prim::kPrimMatMul, prim::kPrimFFN};
    if (!CheckFollowedNodeInSet(func_graph, cnode, support_gather_followed_primitive_types)) {
      MS_LOG(INFO) << "In Ascend ON_THE_FLY quant mode, the cnode following Gather is not BatchMatMul, MatMul or FFN, "
                   << cnode->fullname_with_scope() << " doesn't need weight quant";
      return RET_OK;
    }
  }

  // Init weight quant indices.
  std::vector<int> weight_indices;
  if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
    weight_indices = {2, 3};
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
    weight_indices = {4, 6};
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum)) {
    weight_indices = {2};
  } else {
    for (size_t i = 1; i < cnode->size(); ++i) {
      weight_indices.push_back(i);
    }
  }

  if (linear_quant_) {
    bool is_compression = compression && !is_mixed_bit_ && enable_encode_;
    auto ret = LinearQuant(func_graph, cnode, per_layer_types, symmetric_types, weight_indices, is_compression);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " failed to execute linear weight quantization.";
      return RET_ERROR;
    }
  } else {
    ClusterQuantization cluster;
    auto ret = cluster.KMeansQuantization(cnode, weight_indices);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " failed to execute k-means weight quantization.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

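// Validates one weight input before linear quantization: extracts the
// parameter and its tensor, converts fp16 weights to fp32, and asks the quant
// strategy whether the tensor can be quantized. Returns RET_NO_CHANGE when
// the input should simply be skipped.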
int WeightQuantizer::PreLinearQuant(const CNodePtr &cnode, int idx, const AnfNodePtr &input, ParameterPtr *parameter,
                                    tensor::TensorPtr *tensor_info) {
  CHECK_NULL_RETURN(parameter);
  CHECK_NULL_RETURN(tensor_info);
  GetParameterAndTensor(input, parameter, tensor_info);
  if (*parameter == nullptr || *tensor_info == nullptr ||
      (*tensor_info)->compression_type() != mindspore::kNoCompression) {
    MS_LOG(INFO) << "This op " << input->fullname_with_scope() << " doesn't need weight quant";
    return RET_NO_CHANGE;
  }
  *tensor_info = ConvertParameterFp16TensorToFp32(*parameter);
  if ((*tensor_info) == nullptr || (*tensor_info)->data_type() != TypeId::kNumberTypeFloat32) {
    MS_LOG(INFO) << "This op " << input->fullname_with_scope() << " is null or dtype is not fp32.";
    return RET_NO_CHANGE;
  }
  auto preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32((*tensor_info)->shape()));
  if (quant_strategy_ != nullptr && !quant_strategy_->CanTensorQuantized(cnode, input, preferred_dim)) {
    MS_LOG(INFO) << input->fullname_with_scope() << " will not be quantized";
    return RET_NO_CHANGE;
  }
  return RET_OK;
}

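// Linear (fixed- or mixed-bit) quantization of all weight inputs of a cnode.
// Per-layer and symmetric op types override the default quant settings,
// shared MatMul weights fall back to per-layer quantization, and ON_THE_FLY
// mode additionally inserts dequant nodes after quantization.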
int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                 const std::set<PrimitivePtr> &per_layer_types,
                                 const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices,
                                 bool compression) {
  CHECK_NULL_RETURN(cnode);
  // Avoid affecting other operators
  auto tmp_weight_quant_type = weight_quant_type_;
  if (CheckNodeInSet(cnode, per_layer_types)) {
    tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
  }
  bool symmetric = false;
  int q_min = quant_min_;
  int q_max = quant_max_;
  if (CheckNodeInSet(cnode, symmetric_types)) {
    symmetric = true;
    q_min = symmetric_quant_min_;
    q_max = symmetric_quant_max_;
  }

  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  auto manager = mindspore::Manage(func_graph, true);
  CHECK_NULL_RETURN(manager);
  for (auto idx : weight_indices) {
    auto input = cnode->input(idx);
    ParameterPtr parameter;
    tensor::TensorPtr tensor_info;
    auto status = PreLinearQuant(cnode, idx, input, &parameter, &tensor_info);
    if (status == RET_NO_CHANGE) {
      continue;
    } else if (status != RET_OK) {
      MS_LOG(ERROR) << input->fullname_with_scope() << " pre linear quant failed : " << status;
      return status;
    }
    // support for matmul shared weight
    auto node_map = manager->node_users();
    auto node_user = node_map[input];
    if (node_user.size() > 1 && opt::CheckPrimitiveType(cnode, prim::kPrimMatMulFusion)) {
      MS_LOG(INFO) << input->fullname_with_scope() << " is shared weight.";
      tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
    }
    // linear quant
    int preferred_dim;
    // For MoE Linear, the preferred dim is obtained from the BatchMatMul/MatMul node that follows the Gather node.
    if (ascend_backend_ && dequant_strategy_ == ON_THE_FLY && opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
      auto support_gather_followed_primitive_types = {prim::kPrimBatchMatMul, prim::kPrimMatMul, prim::kPrimFFN};
      if (!CheckFollowedNodeInSet(func_graph, cnode, support_gather_followed_primitive_types)) {
        MS_LOG(INFO) << "In Ascend ON_THE_FLY quant mode, the cnode following Gather is not BatchMatMul, MatMul or FFN, "
                     << cnode->fullname_with_scope() << " doesn't need weight quant";
        return RET_OK;
      }
      preferred_dim = GetFollowedNodePreferredDim(func_graph, cnode, ConvertShapeVectorToInt32(tensor_info->shape()));
    } else {
      preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
    }
    if (is_mixed_bit_) {
      status = DoMixBitQuant(cnode, parameter, idx, tensor_info, preferred_dim, tmp_weight_quant_type, symmetric);
    } else {
      FixedBitWeightQuantization fixed_bit_quant;
      status =
        fixed_bit_quant.QuantFilter(parameter, tensor_info, primitive, quant_type_, q_max, q_min, bit_num_,
                                    tmp_weight_quant_type, type_id_, preferred_dim, symmetric, false, bias_correction_);
    }
    if (status == RET_NO_CHANGE) {
      continue;
    } else if (status != RET_OK) {
      MS_LOG(ERROR) << "QuantFilter failed : " << status;
      return status;
    }
    // Post linear quant
    if (compression) {
      status = DoCompression(cnode, parameter, tensor_info);
      if (status != RET_OK) {
        MS_LOG(ERROR) << cnode->fullname_with_scope() << " compression failed.";
        return status;
      }
    }
    if (dequant_strategy_ == ON_THE_FLY) {
      if (!ascend_backend_) {
        status = InsertDequantNode(func_graph, cnode, parameter, idx, tensor_info);
      } else {
        status = InsertAscendDequantNode(func_graph, cnode, parameter, idx, tensor_info);
      }
      if (status == RET_NO_CHANGE) {
        continue;
      } else if (status != RET_OK) {
        MS_LOG(ERROR) << cnode->fullname_with_scope() << " insert dequant node failed.";
        return status;
      }
    }
    weight_quantized_tensors_.insert(tensor_info);
  }
  return RET_OK;
}

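// Compresses an already-quantized weight tensor. In ON_THE_FLY mode, weights
// below 8 bits are FSE-encoded for inference-time decoding; otherwise sparse
// compression is attempted first, falling back to bit packing for
// non-8/16-bit widths when sparse compression does not apply.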
int WeightQuantizer::DoCompression(const CNodePtr &cnode, const ParameterPtr &parameter,
                                   const tensor::TensorPtr &tensor) {
  int ret = RET_OK;
  auto quantization_params = tensor->quant_params();
  if (quantization_params.empty()) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " tensor: " << tensor->name() << " quantization params empty.";
    return RET_ERROR;
  }
  auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_params.front());
  if (dequant_strategy_ == ON_THE_FLY) {
    if (bit_num_ < k8Bit) {
      FSEEncoder fse_encoder;
      mindspore::TensorCompressionType compress_type =
        dequant_strategy_ == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
      ret = fse_encoder.Compress(parameter, quant_params, compress_type, max_segments_);
      auto new_tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
      CHECK_NULL_RETURN(new_tensor_info);
      weight_quantized_tensors_.insert(new_tensor_info);
      return ret;
    } else {
      return RET_OK;
    }
  }
  TensorCompressor compressor;
  if (type_id_ == kNumberTypeInt8) {
    ret = compressor.DoSparseCompress<int8_t>(parameter, bit_num_, quant_params);
  } else if (type_id_ == kNumberTypeInt16) {
    ret = compressor.DoSparseCompress<int16_t>(parameter, bit_num_, quant_params);
  }
  if (ret != RET_OK) {
    if (bit_num_ != k8Bit && bit_num_ != k16Bit) {
      auto status = compressor.DoBitPack(parameter, bit_num_);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "do bit pack failed. " << status;
        return RET_ERROR;
      }
    }
  } else {
    // The compressed tensor is a new tensor.
    auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
    CHECK_NULL_RETURN(tensor_info);
    weight_quantized_tensors_.insert(tensor_info);
    MS_LOG(INFO) << parameter->fullname_with_scope() << " compression success.";
  }
  return RET_OK;
}

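// Mixed-bit quantization of one weight: searches a per-tensor bit width and
// FSE-encodes the result. If the search fails or yields no change, the layer
// rolls back to 8-bit fixed quantization.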
int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &parameter, int idx,
                                   const tensor::TensorPtr &tensor_info, int preferred_dim,
                                   WeightQuantType weight_quant_type, bool symmetric) {
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  auto mixed_bit_quantization = MixedBitWeightQuantization(mixed_bit_init_scale_);
  auto status = mixed_bit_quantization.QuantFilter(primitive, parameter, tensor_info, quant_type_, is_auto_tune_);
  if (status == RET_OK) {
    FSEEncoder fse_encoder;
    auto quantization_params = tensor_info->quant_params();
    if (quantization_params.empty()) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " tensor: " << tensor_info->name()
                    << " quantization params empty.";
      return RET_ERROR;
    }
    auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_params.front());
    mindspore::TensorCompressionType compress_type =
      dequant_strategy_ == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
    status = fse_encoder.Compress(parameter, quant_params, compress_type);
    auto new_tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
    CHECK_NULL_RETURN(new_tensor_info);
    weight_quantized_tensors_.insert(new_tensor_info);
  }
  // rollback to 8 bit.
  if (status == RET_ERROR || status == RET_NO_CHANGE) {
    const int quant_min = QuantMin(k8Bit, false, false);  // -128
    const int quant_max = QuantMax(k8Bit);                // 127
    MS_LOG(WARNING)
      << parameter->fullname_with_scope()
      << " mixed bit quantization search failed, the current layer rolls back to 8 bit fixed quantization.";
    FixedBitWeightQuantization fixed_bit_quant;
    status = fixed_bit_quant.QuantFilter(parameter, tensor_info, primitive, quant_type_, quant_max, quant_min, bit_num_,
                                         weight_quant_type, kNumberTypeInt8, preferred_dim, symmetric);
  }
  return status;
}

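// Inserts an Ascend AntiQuant node after an int8 weight so that it is
// dequantized on the fly at inference time; FSE-encoded weights are not yet
// supported on this path.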
int WeightQuantizer::InsertAscendDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                             const ParameterPtr &parameter, int idx,
                                             const tensor::TensorPtr &tensor_info) {
  InsertQuantNodeManager quant_manager;
  CHECK_NULL_RETURN(func_graph);
  TypeId type_id;
  auto tensor_name = parameter->fullname_with_scope();
  if (opt::GetDataTypeFromAnfNode(cnode, &type_id) != RET_OK) {
    MS_LOG(WARNING) << cnode->fullname_with_scope() << " Get data type failed.";
    return RET_NO_CHANGE;
  }
  if (parameter->has_default() &&
      parameter->default_param()->cast<tensor::TensorPtr>()->compression_type() == mindspore::kFSEInfer) {
    MS_LOG(ERROR) << tensor_name << " is fse encoded. It will be supported in the future.";
    return RET_ERROR;
  } else {
    MS_LOG(INFO) << tensor_name << " insert Ascend AntiQuant node";
    int axis;
    // For MoE Linear, the preferred dim is obtained from the BatchMatMul/MatMul node that follows the Gather node.
    if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
      auto support_gather_followed_primitive_types = {prim::kPrimBatchMatMul, prim::kPrimMatMul, prim::kPrimFFN};
      if (!CheckFollowedNodeInSet(func_graph, cnode, support_gather_followed_primitive_types)) {
        MS_LOG(INFO) << "In Ascend ON_THE_FLY quant mode, the cnode following Gather is not BatchMatMul, MatMul or FFN, "
                     << cnode->fullname_with_scope() << " doesn't need weight quant";
        return RET_OK;
      }
      axis = GetFollowedNodePreferredDim(func_graph, cnode, ConvertShapeVectorToInt32(tensor_info->shape()));
    } else {
      axis = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
    }
    int status;
    if (type_id == kNumberTypeFloat32) {
      status = quant_manager.InsertAscendAntiQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32,
                                                       axis, param_->chip_name);
    } else {
      status = quant_manager.InsertAscendAntiQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16,
                                                       axis, param_->chip_name);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
      return status;
    }
  }
  return RET_OK;
}

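// Inserts the on-the-fly dequantization node for non-Ascend backends: an
// FSEDecode node for FSE-encoded weights, otherwise a QuantDtypeCast fly node
// matching the cnode's output dtype (fp32 or fp16).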
int WeightQuantizer::InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                       const ParameterPtr &parameter, int idx, const tensor::TensorPtr &tensor_info) {
  InsertQuantNodeManager quant_manager;
  CHECK_NULL_RETURN(func_graph);
  TypeId type_id;
  int status;
  auto tensor_name = parameter->fullname_with_scope();
  if (opt::GetDataTypeFromAnfNode(cnode, &type_id) != RET_OK) {
    MS_LOG(WARNING) << cnode->fullname_with_scope() << " Get data type failed.";
    return RET_NO_CHANGE;
  }
  if (parameter->has_default() &&
      parameter->default_param()->cast<tensor::TensorPtr>()->compression_type() == mindspore::kFSEInfer) {
    MS_LOG(INFO) << tensor_name << " insert FSEDecode node";
    if (type_id == kNumberTypeFloat32) {
      status = quant_manager.InsertFSEDecodeNode(func_graph, cnode, idx, kNumberTypeFloat32);
    } else {
      status = quant_manager.InsertFSEDecodeNode(func_graph, cnode, idx, kNumberTypeFloat16);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << tensor_name << " insert FSEDecode node failed.";
      return status;
    }
  } else {
    MS_LOG(INFO) << tensor_name << " insert WeightQuant node";
    auto axis = GetPreferredDim(cnode, idx - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c()));
    if (type_id == kNumberTypeFloat32) {
      status = quant_manager.InsertQuantDtypeCastFlyNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32,
                                                         axis, true);
    } else {
      status = quant_manager.InsertQuantDtypeCastFlyNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16,
                                                         axis, true);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
      return status;
    }
  }
  return RET_OK;
}

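// Marks a cnode's primitive with the QUANT_WEIGHT attribute if any of its
// parameter inputs was quantized by this pass; shared weights are covered
// because the lookup goes through weight_quantized_tensors_.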
int WeightQuantizer::MarkCNodeWeightQuantType(const CNodePtr &cnode) {
  CHECK_NULL_RETURN(cnode);
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_ERROR;
  }

  auto quant_type_attr = primitive->GetAttr(quant::kQuantType);
  if (quant_type_attr != nullptr) {
    auto quant_type = static_cast<quant::QuantType>(GetValue<int32_t>(quant_type_attr));
    if (quant_type == quant::QUANT_WEIGHT) {
      // already marked with QUANT_WEIGHT
      return RET_OK;
    }
  }

  // Support shared weight quant.
  for (size_t i = kPrimOffset; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    if (input_node->isa<Parameter>()) {
      ParameterPtr param_node;
      tensor::TensorPtr tensor_info;
      GetParameterAndTensor(input_node, &param_node, &tensor_info);
      auto param = weight_quantized_tensors_.find(tensor_info);
      if (param != weight_quantized_tensors_.end()) {
        primitive->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_WEIGHT)));
        continue;
      }
    }
  }
  return RET_OK;
}

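// Walks all cnodes of the graph and marks each one that consumed a quantized
// weight tensor.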
int WeightQuantizer::MarkGraphWeightQuantType(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " primitive is nullptr";
      continue;
    }
    auto status = MarkCNodeWeightQuantType(cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " mark graph QuantType failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

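// Entry point of weight quantization: selects the backend-specific supported
// and per-layer op sets (Ascend, GPTQ or default), runs either the GPTQ
// quantizer or the generic weight-quant pass, and finally marks the quantized
// nodes unless ON_THE_FLY dequant nodes were inserted instead.
//
// A minimal usage sketch (hypothetical driver code; the real converter flow
// constructs the quantizer from its own configuration, so the constructor
// argument below is an assumption):
//   auto quantizer = std::make_shared<WeightQuantizer>(param);
//   if (quantizer->DoQuantize(func_graph) != RET_OK) {
//     MS_LOG(ERROR) << "weight quantization failed";
//   }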
int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  CHECK_NULL_RETURN(func_graph);
  weight_quantized_tensors_.clear();
  std::set<PrimitivePtr> support_primitive_types;
  std::set<PrimitivePtr> per_layer_primitive_types;
  if (ascend_backend_) {
    support_primitive_types = {prim::kPrimMatMulFusion, prim::kPrimBatchMatMul, prim::kPrimMatMul, prim::kPrimGather,
                               prim::kPrimFFN};
    if (per_channel_) {
      per_layer_primitive_types = {};
    } else {
      per_layer_primitive_types = {prim::kPrimMatMulFusion, prim::kPrimMatMul, prim::kPrimBatchMatMul,
                                   prim::kPrimGather, prim::kPrimFFN};
    }
  } else if (param_->weightQuantParam.quant_strategy == quant::GPTQ_ALGORITHM) {
    support_primitive_types = {prim::kPrimMatMulFusion, prim::kPrimBatchMatMul, prim::kPrimMatMul};
  } else {
    support_primitive_types = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
                               prim::kPrimMatMulFusion,  prim::kPrimFullConnection,
                               prim::kPrimLstm,          prim::kPrimGather,
                               prim::kPrimAdam,          prim::kPrimSGD,
                               prim::kPrimApplyMomentum, prim::kPrimConv2D,
                               prim::kPrimMatMul};
    per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
  }
  if (param_->weightQuantParam.quant_strategy == quant::GPTQ_ALGORITHM) {
    std::set<PrimitivePtr> gptq_support_primitive_types = {prim::kPrimMatMulFusion};
    auto gptq_quantizer = std::make_unique<GptqQuantizer>(func_graph, param_, gptq_support_primitive_types);
    CHECK_NULL_RETURN(gptq_quantizer);
    if (gptq_quantizer->DoQuantize() != RET_OK) {
      MS_LOG(ERROR) << "GPTQ weight quant failed.";
      return RET_ERROR;
    }
  } else {
    auto ret = WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types, {});
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Weight Quant failed.";
      return ret;
    }
    if (dequant_strategy_ != ON_THE_FLY) {
      return MarkGraphWeightQuantType(func_graph);
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite::quant