/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API

#include "tools/converter/quantizer/weight_quantizer.h"
#include <list>
#include <string>
#include <utility>
#include <set>
#include "mindspore/core/ops/conv_pool_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/other_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/log_util.h"
#include "tools/converter/quantizer/fse_encoder.h"
#include "tools/converter/quantizer/tensor_compressor.h"
#include "tools/converter/quantizer/cluster_quantization.h"
#include "tools/converter/quantizer/mixed_bit_weight_quantization.h"
#include "tools/converter/quantizer/fixed_bit_weight_quantization.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
#include "tools/common/node_util.h"
#include "src/common/quant_utils.h"
#include "tools/converter/quantizer/gptq_quantizer.h"

namespace mindspore::lite::quant {
namespace {
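// Local helper: the quantization passes below operate on fp32 tensors. If a
// parameter holds an fp16 default tensor, it is widened to fp32 in place (the
// parameter's default value and abstract are replaced) and the new tensor is
// returned; fp32 tensors are returned unchanged.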
tensor::TensorPtr ConvertParameterFp16TensorToFp32(const ParameterPtr &parameter) {
  if (!parameter->has_default()) {
    MS_LOG(WARNING) << parameter->fullname_with_scope() << " not has_default";
    return nullptr;
  }
  auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
  if (tensor_info == nullptr) {
    MS_LOG(WARNING) << "default_param cannot be cast to tensor::Tensor";
    return nullptr;
  }
  if (tensor_info->data_type() == kNumberTypeFloat16) {
    MS_LOG(INFO) << "convert " << parameter->fullname_with_scope() << " from fp16 to fp32.";
    auto data = static_cast<float16 *>(tensor_info->data_c());
    std::vector<float> fp32_data(tensor_info->DataSize());
    for (size_t j = 0; j < tensor_info->DataSize(); j++) {
      fp32_data[j] = mindspore::Float16::ToFloat32(data[j]);
    }
    mindspore::tensor::TensorPtr tensor_ptr = std::make_shared<mindspore::tensor::Tensor>(
      kNumberTypeFloat32, tensor_info->shape_c(), fp32_data.data(), fp32_data.size() * sizeof(float));
    parameter->set_default_param(tensor_ptr);
    parameter->set_abstract(tensor_ptr->ToAbstract());
    return tensor_ptr;
  }
  return tensor_info;
}
}  // namespace
int WeightQuantizer::WeightQuant(const FuncGraphPtr &func_graph,
                                 const std::set<PrimitivePtr> &support_weight_quant_types,
                                 const std::set<PrimitivePtr> &per_layer_types,
                                 const std::set<PrimitivePtr> &symmetric_types, bool compression) {
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto ret =
      WeightQuantPerCNode(func_graph, cnode, support_weight_quant_types, per_layer_types, symmetric_types, compression);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " failed to execute weight quantization.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int WeightQuantizer::WeightQuantPerCNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                         const std::set<PrimitivePtr> &support_weight_quant_types,
                                         const std::set<PrimitivePtr> &per_layer_types,
                                         const std::set<PrimitivePtr> &symmetric_types, bool compression) {
  auto primitive = GetValueNode<std::shared_ptr<Primitive>>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive is nullptr";
    return RET_OK;
  }
  auto op_name = cnode->fullname_with_scope();
  if (skip_quant_node_.find(op_name) != skip_quant_node_.end()) {
    MS_LOG(INFO) << op_name << " is configured to skip quantization.";
    return RET_OK;
  }
  if (!CheckNodeInSet(cnode, support_weight_quant_types)) {
    MS_LOG(INFO) << cnode->fullname_with_scope() << " of type: " << primitive->name() << " doesn't need weight quant.";
    return RET_OK;
  }

  // Ascend ON_THE_FLY quant only supports Gather followed by BatchMatMul, MatMul or FFN.
  if (ascend_backend_ && dequant_strategy_ == ON_THE_FLY && opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
    auto support_gather_followed_primitive_types = {prim::kPrimBatchMatMul, prim::kPrimMatMul, prim::kPrimFFN};
    if (!CheckFollowedNodeInSet(func_graph, cnode, support_gather_followed_primitive_types)) {
      MS_LOG(INFO) << "In Ascend ON_THE_FLY quant mode, the cnode following Gather is not BatchMatMul or MatMul, "
                   << cnode->fullname_with_scope() << " doesn't need weight quant";
      return RET_OK;
    }
  }

  // Init weight quant indices.
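  // Optimizer nodes (Adam, SGD, ApplyMomentum) only quantize specific input
  // slots; for every other op, all inputs after the primitive (index 1 onwards)
  // are candidates.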
  std::vector<int> weight_indices;
  if (opt::CheckPrimitiveType(cnode, prim::kPrimAdam)) {
    weight_indices = {2, 3};
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimSGD)) {
    weight_indices = {4, 6};
  } else if (opt::CheckPrimitiveType(cnode, prim::kPrimApplyMomentum)) {
    weight_indices = {2};
  } else {
    for (size_t i = 1; i < cnode->size(); ++i) {
      weight_indices.push_back(i);
    }
  }

  if (linear_quant_) {
    bool is_compression = compression && !is_mixed_bit_ && enable_encode_;
    auto ret = LinearQuant(func_graph, cnode, per_layer_types, symmetric_types, weight_indices, is_compression);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " failed to execute linear weight quantization.";
      return RET_ERROR;
    }
  } else {
    ClusterQuantization cluster;
    auto ret = cluster.KMeansQuantization(cnode, weight_indices);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " failed to execute k-means weight quantization.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int WeightQuantizer::PreLinearQuant(const CNodePtr &cnode, int idx, const AnfNodePtr &input, ParameterPtr *parameter,
                                    tensor::TensorPtr *tensor_info) {
  CHECK_NULL_RETURN(parameter);
  CHECK_NULL_RETURN(tensor_info);
  GetParameterAndTensor(input, parameter, tensor_info);
  if (*parameter == nullptr || *tensor_info == nullptr ||
      (*tensor_info)->compression_type() != mindspore::kNoCompression) {
    MS_LOG(INFO) << "This op " << input->fullname_with_scope() << " doesn't need weight quant";
    return RET_NO_CHANGE;
  }
  *tensor_info = ConvertParameterFp16TensorToFp32(*parameter);
  if ((*tensor_info) == nullptr || (*tensor_info)->data_type() != TypeId::kNumberTypeFloat32) {
    MS_LOG(INFO) << "This op " << input->fullname_with_scope() << " is null or dtype is not fp32.";
    return RET_NO_CHANGE;
  }
  auto preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32((*tensor_info)->shape()));
  if (quant_strategy_ != nullptr && !quant_strategy_->CanTensorQuantized(cnode, input, preferred_dim)) {
    MS_LOG(INFO) << input->fullname_with_scope() << " will not be quantized";
    return RET_NO_CHANGE;
  }
  return RET_OK;
}

int WeightQuantizer::LinearQuant(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                 const std::set<PrimitivePtr> &per_layer_types,
                                 const std::set<PrimitivePtr> &symmetric_types, const std::vector<int> &weight_indices,
                                 bool compression) {
  CHECK_NULL_RETURN(cnode);
  // Use a per-node copy so overriding the quant type here doesn't affect other operators.
  auto tmp_weight_quant_type = weight_quant_type_;
  if (CheckNodeInSet(cnode, per_layer_types)) {
    tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
  }
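  // Symmetric ops use a symmetric integer range so that zero is exactly
  // representable (for 8-bit, typically [-127, 127] instead of [-128, 127];
  // assumed typical values, the actual bounds come from symmetric_quant_min_
  // and symmetric_quant_max_).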
  bool symmetric = false;
  int q_min = quant_min_;
  int q_max = quant_max_;
  if (CheckNodeInSet(cnode, symmetric_types)) {
    symmetric = true;
    q_min = symmetric_quant_min_;
    q_max = symmetric_quant_max_;
  }

  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  auto manager = mindspore::Manage(func_graph, true);
  CHECK_NULL_RETURN(manager);
  for (auto idx : weight_indices) {
    auto input = cnode->input(idx);
    ParameterPtr parameter;
    tensor::TensorPtr tensor_info;
    auto status = PreLinearQuant(cnode, idx, input, &parameter, &tensor_info);
    if (status == RET_NO_CHANGE) {
      continue;
    } else if (status != RET_OK) {
      MS_LOG(ERROR) << input->fullname_with_scope() << " pre linear quant failed : " << status;
      return status;
    }
    // Support for MatMul shared weights.
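    // A weight feeding more than one node is forced back to per-layer
    // quantization, presumably because a single per-channel parameter set
    // cannot match every user's channel layout (rationale assumed, not stated
    // in the source).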
    auto node_map = manager->node_users();
    auto node_user = node_map[input];
    if (node_user.size() > 1 && opt::CheckPrimitiveType(cnode, prim::kPrimMatMulFusion)) {
      MS_LOG(INFO) << input->fullname_with_scope() << " is shared weight.";
      tmp_weight_quant_type = WeightQuantType::FIXED_BIT_PER_LAYER;
    }
    // linear quant
    int preferred_dim;
    // For MOE Linear, the preferred dim is taken from the BatchMatMul/MatMul node that follows the Gather node.
    if (ascend_backend_ && dequant_strategy_ == ON_THE_FLY && opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
      auto support_gather_followed_primitive_types = {prim::kPrimBatchMatMul, prim::kPrimMatMul, prim::kPrimFFN};
      if (!CheckFollowedNodeInSet(func_graph, cnode, support_gather_followed_primitive_types)) {
        MS_LOG(INFO) << "In Ascend ON_THE_FLY quant mode, the cnode following Gather is not BatchMatMul or MatMul, "
                     << cnode->fullname_with_scope() << " doesn't need weight quant";
        return RET_OK;
      }
      preferred_dim = GetFollowedNodePreferredDim(func_graph, cnode, ConvertShapeVectorToInt32(tensor_info->shape()));
    } else {
      preferred_dim = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
    }
    if (is_mixed_bit_) {
      status = DoMixBitQuant(cnode, parameter, idx, tensor_info, preferred_dim, tmp_weight_quant_type, symmetric);
    } else {
      FixedBitWeightQuantization fixed_bit_quant;
      status =
        fixed_bit_quant.QuantFilter(parameter, tensor_info, primitive, quant_type_, q_max, q_min, bit_num_,
                                    tmp_weight_quant_type, type_id_, preferred_dim, symmetric, false, bias_correction_);
    }
    if (status == RET_NO_CHANGE) {
      continue;
    } else if (status != RET_OK) {
      MS_LOG(ERROR) << "QuantFilter failed : " << status;
      return status;
    }
    // Post linear quant
    if (compression) {
      status = DoCompression(cnode, parameter, tensor_info);
      if (status != RET_OK) {
        MS_LOG(ERROR) << cnode->fullname_with_scope() << " compression failed.";
        return status;
      }
    }
    if (dequant_strategy_ == ON_THE_FLY) {
      if (!ascend_backend_) {
        status = InsertDequantNode(func_graph, cnode, parameter, idx, tensor_info);
      } else {
        status = InsertAscendDequantNode(func_graph, cnode, parameter, idx, tensor_info);
      }
      if (status == RET_NO_CHANGE) {
        continue;
      } else if (status != RET_OK) {
        MS_LOG(ERROR) << cnode->fullname_with_scope() << " insert dequant node failed.";
        return status;
      }
    }
    weight_quantized_tensors_.insert(tensor_info);
  }
  return RET_OK;
}

int WeightQuantizer::DoCompression(const CNodePtr &cnode, const ParameterPtr &parameter,
                                   const tensor::TensorPtr &tensor) {
  int ret = RET_OK;
  auto quantization_params = tensor->quant_params();
  if (quantization_params.empty()) {
    MS_LOG(ERROR) << cnode->fullname_with_scope() << " tensor: " << tensor->name() << " quantization params empty.";
    return RET_ERROR;
  }
  auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_params.front());
  if (dequant_strategy_ == ON_THE_FLY) {
    if (bit_num_ < k8Bit) {
      FSEEncoder fse_encoder;
      mindspore::TensorCompressionType compress_type =
        dequant_strategy_ == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
      ret = fse_encoder.Compress(parameter, quant_params, compress_type, max_segments_);
      auto new_tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
      CHECK_NULL_RETURN(new_tensor_info);
      weight_quantized_tensors_.insert(new_tensor_info);
      return ret;
    } else {
      return RET_OK;
    }
  }
  TensorCompressor compressor;
  if (type_id_ == kNumberTypeInt8) {
    ret = compressor.DoSparseCompress<int8_t>(parameter, bit_num_, quant_params);
  } else if (type_id_ == kNumberTypeInt16) {
    ret = compressor.DoSparseCompress<int16_t>(parameter, bit_num_, quant_params);
  }
  if (ret != RET_OK) {
    if (bit_num_ != k8Bit && bit_num_ != k16Bit) {
      auto status = compressor.DoBitPack(parameter, bit_num_);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "do bit pack failed. " << status;
        return RET_ERROR;
      }
    }
  } else {
    // compressed tensor is a new tensor.
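    // Sparse compression replaced the parameter's default tensor, so record
    // the new tensor here; MarkGraphWeightQuantType later looks tensors up in
    // weight_quantized_tensors_.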
    auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
    CHECK_NULL_RETURN(tensor_info);
    weight_quantized_tensors_.insert(tensor_info);
    MS_LOG(INFO) << parameter->fullname_with_scope() << " compression success.";
  }
  return RET_OK;
}

int WeightQuantizer::DoMixBitQuant(const CNodePtr &cnode, const ParameterPtr &parameter, int idx,
                                   const tensor::TensorPtr &tensor_info, int preferred_dim,
                                   WeightQuantType weight_quant_type, bool symmetric) {
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  auto mixed_bit_quantization = MixedBitWeightQuantization(mixed_bit_init_scale_);
  auto status = mixed_bit_quantization.QuantFilter(primitive, parameter, tensor_info, quant_type_, is_auto_tune_);
  if (status == RET_OK) {
    FSEEncoder fse_encoder;
    auto quantization_params = tensor_info->quant_params();
    if (quantization_params.empty()) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " tensor: " << tensor_info->name()
                    << " quantization params empty.";
      return RET_ERROR;
    }
    auto quant_params = quant::ConvertQuantizationParamToQuantParamT(quantization_params.front());
    mindspore::TensorCompressionType compress_type =
      dequant_strategy_ == ON_THE_FLY ? mindspore::kFSEInfer : mindspore::kFSE;
    status = fse_encoder.Compress(parameter, quant_params, compress_type);
    auto new_tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
    CHECK_NULL_RETURN(new_tensor_info);
    weight_quantized_tensors_.insert(new_tensor_info);
  }
  // rollback to 8 bit.
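  // If the mixed-bit search cannot find a workable bit assignment (RET_ERROR
  // or RET_NO_CHANGE), this layer falls back to fixed 8-bit quantization.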
  if (status == RET_ERROR || status == RET_NO_CHANGE) {
    const int quant_min = QuantMin(k8Bit, false, false);  // -128
    const int quant_max = QuantMax(k8Bit);                // 127
    MS_LOG(WARNING)
      << parameter->fullname_with_scope()
      << " mixed bit quantization search failed, the current layer rolls back to 8 bit fixed quantization.";
    FixedBitWeightQuantization fixed_bit_quant;
    status = fixed_bit_quant.QuantFilter(parameter, tensor_info, primitive, quant_type_, quant_max, quant_min, bit_num_,
                                         weight_quant_type, kNumberTypeInt8, preferred_dim, symmetric);
  }
  return status;
}

int WeightQuantizer::InsertAscendDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                             const ParameterPtr &parameter, int idx,
                                             const tensor::TensorPtr &tensor_info) {
  InsertQuantNodeManager quant_manager;
  CHECK_NULL_RETURN(func_graph);
  TypeId type_id;
  auto tensor_name = parameter->fullname_with_scope();
  if (opt::GetDataTypeFromAnfNode(cnode, &type_id) != RET_OK) {
    MS_LOG(WARNING) << cnode->fullname_with_scope() << " Get data type failed.";
    return RET_NO_CHANGE;
  }
  if (parameter->has_default() &&
      parameter->default_param()->cast<tensor::TensorPtr>()->compression_type() == mindspore::kFSEInfer) {
    MS_LOG(ERROR) << tensor_name << " is FSE encoded. This will be supported in the future.";
    return RET_ERROR;
  } else {
    MS_LOG(INFO) << tensor_name << " insert Ascend AntiQuant node";
    int axis;
    // For MOE Linear, the preferred dim is taken from the BatchMatMul/MatMul node that follows the Gather node.
    if (opt::CheckPrimitiveType(cnode, prim::kPrimGather)) {
      auto support_gather_followed_primitive_types = {prim::kPrimBatchMatMul, prim::kPrimMatMul, prim::kPrimFFN};
      if (!CheckFollowedNodeInSet(func_graph, cnode, support_gather_followed_primitive_types)) {
        MS_LOG(INFO) << "In Ascend ON_THE_FLY quant mode, the cnode following Gather is not BatchMatMul or MatMul, "
                     << cnode->fullname_with_scope() << " doesn't need weight quant";
        return RET_OK;
      }
      axis = GetFollowedNodePreferredDim(func_graph, cnode, ConvertShapeVectorToInt32(tensor_info->shape()));
    } else {
      axis = GetPreferredDim(cnode, idx - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
    }
    int status;
    if (type_id == kNumberTypeFloat32) {
      status = quant_manager.InsertAscendAntiQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32,
                                                       axis, param_->chip_name);
    } else {
      status = quant_manager.InsertAscendAntiQuantNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16,
                                                       axis, param_->chip_name);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
      return status;
    }
  }
  return RET_OK;
}

int WeightQuantizer::InsertDequantNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode,
                                       const ParameterPtr &parameter, int idx, const tensor::TensorPtr &tensor_info) {
  InsertQuantNodeManager quant_manager;
  CHECK_NULL_RETURN(func_graph);
  TypeId type_id;
  int status;
  auto tensor_name = parameter->fullname_with_scope();
  if (opt::GetDataTypeFromAnfNode(cnode, &type_id) != RET_OK) {
    MS_LOG(WARNING) << cnode->fullname_with_scope() << " Get data type failed.";
    return RET_NO_CHANGE;
  }
  if (parameter->has_default() &&
      parameter->default_param()->cast<tensor::TensorPtr>()->compression_type() == mindspore::kFSEInfer) {
    MS_LOG(INFO) << tensor_name << " insert FSEDecode node";
    if (type_id == kNumberTypeFloat32) {
      status = quant_manager.InsertFSEDecodeNode(func_graph, cnode, idx, kNumberTypeFloat32);
    } else {
      status = quant_manager.InsertFSEDecodeNode(func_graph, cnode, idx, kNumberTypeFloat16);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << tensor_name << " insert FSEDecode node failed.";
      return status;
    }
  } else {
    MS_LOG(INFO) << tensor_name << " insert WeightQuant node";
    auto axis = GetPreferredDim(cnode, idx - kPrimOffset, ConvertShapeVectorToInt32(tensor_info->shape_c()));
    if (type_id == kNumberTypeFloat32) {
      status = quant_manager.InsertQuantDtypeCastFlyNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat32,
                                                         axis, true);
    } else {
      status = quant_manager.InsertQuantDtypeCastFlyNode(func_graph, cnode, idx, kNumberTypeInt8, kNumberTypeFloat16,
                                                         axis, true);
    }
    if (status != RET_OK) {
      MS_LOG(ERROR) << tensor_name << " insert weight quant node failed.";
      return status;
    }
  }
  return RET_OK;
}

int WeightQuantizer::MarkCNodeWeightQuantType(const CNodePtr &cnode) {
  CHECK_NULL_RETURN(cnode);
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    MS_LOG(ERROR) << "primitive is nullptr";
    return RET_ERROR;
  }

  auto quant_type_attr = primitive->GetAttr(quant::kQuantType);
  if (quant_type_attr != nullptr) {
    auto quant_type = static_cast<quant::QuantType>(GetValue<int32_t>(quant_type_attr));
    if (quant_type == quant::QUANT_WEIGHT) {
      // already marked with QUANT_WEIGHT
      return RET_OK;
    }
  }

  // Support Share Weight Quant.
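  // A parameter may feed several cnodes. If any parameter input of this cnode
  // resolves to a tensor recorded in weight_quantized_tensors_, the cnode is
  // marked QUANT_WEIGHT, so shared weights mark all of their users.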
  for (size_t i = kPrimOffset; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    if (input_node->isa<Parameter>()) {
      ParameterPtr param_node;
      tensor::TensorPtr tensor_info;
      GetParameterAndTensor(input_node, &param_node, &tensor_info);
      auto param = weight_quantized_tensors_.find(tensor_info);
      if (param != weight_quantized_tensors_.end()) {
        primitive->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_WEIGHT)));
        continue;
      }
    }
  }
  return RET_OK;
}

int WeightQuantizer::MarkGraphWeightQuantType(const FuncGraphPtr &func_graph) {
  MS_CHECK_TRUE_RET(func_graph != nullptr, RET_NULL_PTR);
  for (auto &cnode : func_graph->GetOrderedCnodes()) {
    auto primitive = GetValueNode<std::shared_ptr<ops::PrimitiveC>>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(DEBUG) << cnode->fullname_with_scope() << " primitive is nullptr";
      continue;
    }
    auto status = MarkCNodeWeightQuantType(cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << cnode->fullname_with_scope() << " mark graph QuantType failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

int WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  CHECK_NULL_RETURN(func_graph);
  weight_quantized_tensors_.clear();
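  // Select the op types eligible for weight quant for the current backend and
  // algorithm; types in per_layer_primitive_types are quantized per layer
  // rather than per channel.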
  std::set<PrimitivePtr> support_primitive_types;
  std::set<PrimitivePtr> per_layer_primitive_types;
  if (ascend_backend_) {
    support_primitive_types = {prim::kPrimMatMulFusion, prim::kPrimBatchMatMul, prim::kPrimMatMul, prim::kPrimGather,
                               prim::kPrimFFN};
    if (per_channel_) {
      per_layer_primitive_types = {};
    } else {
      per_layer_primitive_types = {prim::kPrimMatMulFusion, prim::kPrimMatMul, prim::kPrimBatchMatMul,
                                   prim::kPrimGather, prim::kPrimFFN};
    }
  } else if (param_->weightQuantParam.quant_strategy == quant::GPTQ_ALGORITHM) {
    support_primitive_types = {prim::kPrimMatMulFusion, prim::kPrimBatchMatMul, prim::kPrimMatMul};
  } else {
    support_primitive_types = {prim::kPrimConv2DFusion,  prim::kPrimConv2dTransposeFusion,
                               prim::kPrimMatMulFusion,  prim::kPrimFullConnection,
                               prim::kPrimLstm,          prim::kPrimGather,
                               prim::kPrimAdam,          prim::kPrimSGD,
                               prim::kPrimApplyMomentum, prim::kPrimConv2D,
                               prim::kPrimMatMul};
    per_layer_primitive_types = {prim::kPrimAdam, prim::kPrimSGD, prim::kPrimApplyMomentum};
  }
  if (param_->weightQuantParam.quant_strategy == quant::GPTQ_ALGORITHM) {
    std::set<PrimitivePtr> gptq_support_primitive_types = {prim::kPrimMatMulFusion};
    auto GPTQ = std::make_unique<GptqQuantizer>(func_graph, param_, gptq_support_primitive_types);
    CHECK_NULL_RETURN(GPTQ);
    if (GPTQ->DoQuantize() != RET_OK) {
      MS_LOG(ERROR) << "GPTQ weight quant failed.";
      return RET_ERROR;
    }
  } else {
    auto ret = WeightQuant(func_graph, support_primitive_types, per_layer_primitive_types, {});
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Weight Quant failed.";
      return ret;
    }
    if (dequant_strategy_ != ON_THE_FLY) {
      return MarkGraphWeightQuantType(func_graph);
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite::quant