/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API

#include "tools/converter/quantizer/full_quant_quantizer.h"
#include <dirent.h>
#include <memory>
#include <unordered_map>
#include <string>
#include <vector>
#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/nn_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/lite_ops.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ops/tuple_get_item.h"
#include "src/tensor.h"
#include "tools/converter/quantizer/insert_quant_node_manager.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "src/common/log_adapter.h"
#include "tools/common/tensor_util.h"
#include "src/common/utils.h"
#include "tools/common/node_util.h"
#include "nnacl/op_base.h"
#include "src/common/log_util.h"
#include "tools/converter/quantizer/bias_correction_strategy.h"

using std::string;
using std::vector;

namespace mindspore::lite::quant {
FullQuantQuantizer::~FullQuantQuantizer() {}

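// Build activation quant params from the calibrated data distribution: scale, zero point, and encode
// min/max come from the DataDistribution collected during calibration inference. Only fp32 tensors with
// calibration info are handled; otherwise an empty vector is returned.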
std::vector<schema::QuantParamT> FullQuantQuantizer::GetQuantParam(
  const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info) const {
  std::vector<schema::QuantParamT> quant_params;
  schema::QuantParamT quant_param;
  TypeId type_id = kTypeUnknown;
  if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
    MS_LOG(ERROR) << "Get data type failed.";
    return quant_params;
  }
  if (type_id == kNumberTypeFloat32 && info != nullptr) {
    quant_param.scale = info->GetScale();
    quant_param.zeroPoint = info->GetZeroPoint();
    quant_param.max = info->GetEncodeMax();
    quant_param.min = info->GetEncodeMin();
    quant_param.numBits = init_param_.bit_num_;
    quant_param.narrowRange = true;
    quant_param.inited = true;
    quant_param.roundType = 1;
    quant_param.multiplier = 1;
    quant_params.push_back(quant_param);
  }
  return quant_params;
}

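// Quantize a weight tensor to int8 with fixed-bit quantization, choosing per-channel or per-layer
// ranges and symmetry according to the device parameters set up in Init*Config/InitQMinMax.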
int FullQuantQuantizer::QuantWeight(const CNodePtr &cnode, const PrimitivePtr &primitive, const AnfNodePtr &weight,
                                    int input_index, const tensor::TensorPtr &tensor_info, bool per_channel) {
  int preferred_dim = GetPreferredDim(cnode, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
  auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
  auto weight_q_min = per_channel ? init_param_.weight_channel_q_min_ : init_param_.weight_layer_q_min_;
  auto weight_q_max = per_channel ? init_param_.weight_channel_q_max_ : init_param_.weight_layer_q_max_;
  auto symmetric = per_channel ? init_param_.weight_channel_symmetric_ : init_param_.weight_layer_symmetric_;
  return fixed_bit_quant_.QuantFilter(weight, tensor_info, primitive, quant::QUANT_ALL, weight_q_max, weight_q_min,
                                      init_param_.bit_num_, weight_quant_type, kNumberTypeInt8, preferred_dim,
                                      symmetric);
}

int FullQuantQuantizer::DoParameterWeightQuant(const CNodePtr &cnode, const ParameterPtr &weight,
                                               const PrimitivePtr &primitive, int input_index, bool per_channel) {
  CHECK_NULL_RETURN(cnode);
  CHECK_NULL_RETURN(weight);
  CHECK_NULL_RETURN(primitive);
  auto tensor_info = weight->default_param()->cast<tensor::TensorPtr>();
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << weight->fullname_with_scope() << " can't get value";
    return RET_NULL_PTR;
  }
  return QuantWeight(cnode, primitive, weight, input_index, tensor_info, per_channel);
}

int FullQuantQuantizer::DoValueNodeWeightQuant(const CNodePtr &cnode, const ValueNodePtr &weight,
                                               const PrimitivePtr &primitive, int input_index, bool per_channel) {
  CHECK_NULL_RETURN(weight);
  CHECK_NULL_RETURN(primitive);
  auto tensor_info = weight->value()->cast<tensor::TensorPtr>();
  if (tensor_info == nullptr) {
    MS_LOG(ERROR) << weight->fullname_with_scope() << " can't get value";
    return RET_NULL_PTR;
  }
  return QuantWeight(cnode, primitive, weight, input_index, tensor_info, per_channel);
}

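// Check whether an input node can be weight-quantized: fp32 data returns RET_OK; int8 data is treated as
// a shared weight that was already quantized (RET_NO_CHANGE if its quant params were backed up, RET_ERROR
// otherwise); any other data type returns RET_NO_CHANGE.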
int FullQuantQuantizer::IsSupportWeightQuant(const AnfNodePtr &input_node) {
  TypeId type_id = kTypeUnknown;
  if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
    MS_LOG(ERROR) << "Get data type failed.";
    return RET_ERROR;
  }
  // Support for shared weights.
  if (type_id == kNumberTypeInt8) {
    auto iter = weight_quant_params_bak_.find(input_node->fullname_with_scope());
    if (iter == weight_quant_params_bak_.end()) {
      return RET_ERROR;
    } else {
      return RET_NO_CHANGE;
    }
  }
  // Only fp32 data can be quantized.
  if (type_id != kNumberTypeFloat32) {
    return RET_NO_CHANGE;
  }
  return RET_OK;
}

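// Quantize one parameter input of a cnode: the bias input of ops in kHasBiasOperator is quantized from the
// activation and weight quant params; other weights go through per-channel or per-layer quantization. The
// resulting quant params are backed up so shared weights are not quantized twice.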
int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPtr &input_node,
                                             size_t input_index) {
  auto ret = IsSupportWeightQuant(input_node);
  if (ret != RET_OK) {
    return ret;
  }
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  auto op_name = cnode->fullname_with_scope();
  if (input_index == THIRD_INPUT + kPrimOffset && CheckNodeInSet(cnode, kHasBiasOperator)) {
    auto weight_parameter = cnode->input(SECOND_INPUT + kPrimOffset)->cast<ParameterPtr>();
    auto active_quant_params = quant::GetInputNodeQuantParam(cnode, FIRST_INPUT + kPrimOffset);
    ret = fixed_bit_quant_.QuantBias(weight_parameter, input_node, active_quant_params);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << op_name << " Do bias quant failed.";
      return ret;
    }
  } else if (param_->fullQuantParam.per_channel && CheckNodeInSet(cnode, per_channel_ops_)) {
    ret = DoParameterWeightQuant(cnode, input_node, primitive, input_index, true);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << op_name << " Do per-channel weight quant failed.";
      return ret;
    }
  } else {
    ret = DoParameterWeightQuant(cnode, input_node, primitive, input_index, false);
    if (ret != RET_OK) {
      MS_LOG(ERROR) << op_name << " Do per-layer weight quant failed.";
      return ret;
    }
  }
  // Back up quant params to support shared weights.
  auto tensor_info = input_node->default_param()->cast<tensor::TensorPtr>();
  if (tensor_info->quant_params().empty()) {
    return RET_NO_CHANGE;
  }
  auto quant_params = quant::ConvertQuantizationParamToQuantParamT(tensor_info->quant_params().front());
  weight_quant_params_bak_[input_node->fullname_with_scope()] = quant_params;
  return RET_OK;
}

int FullQuantQuantizer::DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &input_node, size_t input_index) {
  auto ret = IsSupportWeightQuant(input_node);
  if (ret != RET_OK) {
    return ret;
  }
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  auto op_name = cnode->fullname_with_scope();
  ret = DoValueNodeWeightQuant(cnode, input_node, primitive, input_index, false);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << op_name << " Do value node weight quant failed.";
    return ret;
  }
  return RET_OK;
}

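// Attach quant params for a graph input: instead of rewriting the input tensor, the params are stored on
// the consuming primitive under kGraphInputQuantParam for later passes to pick up.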
int FullQuantQuantizer::QuantNodeGraphInput(const PrimitivePtr &primitive, const AnfNodePtr &input_node,
                                            const std::unique_ptr<DataDistribution> &info) {
  TypeId type_id = kTypeUnknown;
  if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
    MS_LOG(ERROR) << "Get data type failed.";
    return RET_ERROR;
  }
  if (type_id == kNumberTypeFloat32 && info != nullptr) {
    schema::QuantParamT quant_param;
    quant_param.scale = info->GetScale();
    quant_param.zeroPoint = info->GetZeroPoint();
    quant_param.max = info->GetEncodeMax();
    quant_param.min = info->GetEncodeMin();
    quant_param.numBits = static_cast<int32_t>(init_param_.bit_num_);
    quant_param.narrowRange = true;
    quant_param.inited = true;
    quant_param.roundType = 1;
    quant_param.multiplier = 1;
    auto quantization_param = quant::ConvertQuantParamTToQuantizationParam({quant_param});
    primitive->AddAttr(quant::kGraphInputQuantParam, quantization_param);
  }
  return RET_OK;
}

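// Attach activation quant params to an upstream cnode's primitive under kQuantParam, skipping
// nodes whose quant params already exist.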
int FullQuantQuantizer::QuantNodeCNode(const CNodePtr &cnode, const AnfNodePtr &input_node,
                                       const std::unique_ptr<DataDistribution> &info) {
  auto input_cnode = input_node->cast<mindspore::CNodePtr>();
  MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "input_cnode is nullptr.");
  auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
  CHECK_NULL_RETURN(input_cnode_primitive);
  if (input_cnode_primitive->HasAttr(quant::kQuantParam)) {
    MS_LOG(INFO) << input_node->fullname_with_scope() << " quant param already exists.";
    return RET_NO_CHANGE;
  }
  auto quant_params = GetQuantParam(input_node, info);
  if (quant_params.empty()) {
    MS_LOG(INFO) << input_node->fullname_with_scope() << " quant param does not exist.";
    return RET_NO_CHANGE;
  }
  auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(quant_params);
  if (quantization_ptr != nullptr) {
    std::vector<ValuePtr> quantization_list = {quantization_ptr};
    input_cnode_primitive->AddAttr(quant::kQuantParam, std::make_shared<ValueList>(quantization_list));
  }
  return RET_OK;
}

int FullQuantQuantizer::QuantValueNode(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t i) {
  if (init_param_.weight_data_type_ == kTypeUnknown) {
    MS_LOG(INFO) << "weight parameters do not need to be quantized.";
    return RET_NO_CHANGE;
  }
  auto value_node = input_node->cast<ValueNodePtr>();
  auto ret = DoValueNodeQuant(cnode, value_node, i);
  if (ret == RET_NO_CHANGE) {
    return RET_NO_CHANGE;
  } else if (ret != RET_OK) {
    MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do value node quant failed.";
    return ret;
  }
  // Back up quant params to support shared weights.
  auto tensor_info = value_node->value()->cast<tensor::TensorPtr>();
  if (tensor_info->quant_params().empty()) {
    return RET_NO_CHANGE;
  }
  auto quant_params = quant::ConvertQuantizationParamToQuantParamT(tensor_info->quant_params().front());
  weight_quant_params_bak_[input_node->fullname_with_scope()] = quant_params;
  return RET_OK;
}

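// Quantize every input of a single cnode, dispatching on the input kind: graph inputs get params from the
// calibrator's divergence info, upstream cnodes get activation params, and parameter / value nodes go
// through weight quantization.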
int FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  auto op_name = cnode->fullname_with_scope();
  MS_ASSERT(cnode->size() - 1 <= (*inputs_diverg_info)[op_name].size());
  int ret;
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    CHECK_NULL_RETURN(input_node);
    bool is_graph_input = IsGraphInput(input_node);
    if (is_graph_input) {
      // do input quant
      auto &info = (*inputs_diverg_info)[op_name][i - 1];
      if (info == nullptr) {
        MS_LOG(INFO) << input_node->fullname_with_scope() << " quant info not exist.";
        continue;
      }
      if (QuantNodeGraphInput(primitive, input_node, info) != RET_OK) {
        MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do graph input node quant failed.";
        return RET_ERROR;
      }
    } else if (input_node->isa<mindspore::CNode>()) {
      auto &info = (*inputs_diverg_info)[op_name][i - 1];
      ret = QuantNodeCNode(cnode, input_node, info);
      if (ret != RET_NO_CHANGE && ret != RET_OK) {
        MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do cnode quant failed.";
        return RET_ERROR;
      }
    } else if (input_node->isa<mindspore::Parameter>()) {
      if (init_param_.weight_data_type_ == kTypeUnknown) {
        MS_LOG(INFO) << "weight parameters do not need to be quantized.";
        continue;
      }
      auto parameter_node = input_node->cast<ParameterPtr>();
      ret = DoParameterNodeQuant(cnode, parameter_node, i);
      if (ret == RET_NO_CHANGE) {
        continue;
      } else if (ret != RET_OK) {
        MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do parameter node quant failed.";
        return ret;
      }
    } else if (input_node->isa<mindspore::ValueNode>()) {
      ret = QuantValueNode(cnode, input_node, i);
      if (ret == RET_NO_CHANGE) {
        continue;
      } else if (ret != RET_OK) {
        MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do value node quant failed.";
        return ret;
      }
    } else {
      MS_LOG(ERROR) << input_node->fullname_with_scope() << ":" << input_node->type_name()
                    << " is not a supported type";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

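// Walk the whole graph in execution order and attach quant params to every quantizable cnode. Nodes that
// were not calibrated are marked QUANT_NONE; TupleGetItem forwards the quant params of the selected output
// of its input; all other ops get per-input and per-output params and are marked QUANT_ALL.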
int FullQuantQuantizer::QuantNode(const FuncGraphPtr &func_graph) {
  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
  auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();

  auto cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
      continue;
    }
    auto op_name = cnode->fullname_with_scope();
    MS_LOG(INFO) << "Quant node op name: " << op_name;
    auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "primitive is nullptr";
      return RET_ERROR;
    }
    if (inputs_diverg_info->find(op_name) == inputs_diverg_info->end()) {
      MS_LOG(INFO) << op_name << " cannot do quant";
      primitive->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_NONE)));
      continue;
    }

    auto op_type = primitive->name();
    if (op_type == mindspore::ops::kNameTupleGetItem) {
      constexpr int tuple_get_item_input_size = 3;
      MS_CHECK_TRUE_MSG(cnode->size() == tuple_get_item_input_size, RET_ERROR, "cnode->size() != 3");
      auto index_node = cnode->input(THIRD_INPUT);
      auto index_value_node = index_node->cast<mindspore::ValueNodePtr>();
      if (index_value_node == nullptr) {
        MS_LOG(WARNING) << "index value node is null";
        continue;
      }
      size_t index = static_cast<size_t>(opt::CastToInt(index_value_node->value()).front());
      auto input_node_quant_params = quant::GetInputNodeQuantParam(cnode, FIRST_INPUT + kPrimOffset, index);
      std::vector<ValuePtr> quantization_list;
      auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(input_node_quant_params);
      if (quantization_ptr != nullptr) {
        quantization_list.push_back(quantization_ptr);
        primitive->AddAttr(quant::kQuantParam, std::make_shared<ValueList>(quantization_list));
        primitive->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_ALL)));
      } else {
        MS_LOG(WARNING) << cnode->fullname_with_scope()
                        << " this TupleGetItem node's input_node_quant_params is empty.";
      }
      continue;
    } else {  // do simple op quant
      auto status = QuantNodeSimpleOp(cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "simple op quant failed.";
        return status;
      }
    }
    // do output quant; a node may have multiple outputs
    auto &infos = (*outputs_diverg_info)[op_name];
    std::vector<ValuePtr> quantization_list;
    for (size_t index = 0; index < infos.size(); index++) {
      auto &info = infos.at(index);
      auto quant_params = GetQuantParam(cnode, info);
      auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(quant_params);
      if (quantization_ptr != nullptr) {
        quantization_list.push_back(quantization_ptr);
      }
    }
    primitive->AddAttr(quant::kQuantParam, std::make_shared<ValueList>(quantization_list));
    primitive->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_ALL)));
  }
  return RET_OK;
}

int FullQuantQuantizer::UpdateDivergeInterval() {
  auto ret = this->calibrator_->UpdateDivergInterval();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Update input diverge interval failed.";
    return ret;
  }
  return RET_OK;
}

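// Refine activation ranges with the KL-divergence method: update the divergence intervals from the
// recorded min/max, re-run inference to collect histograms (KL_BIN), then compute the best thresholds.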
int FullQuantQuantizer::QuantWithKL() {
  MS_LOG(INFO) << "start to update divergence's interval";
  auto status = UpdateDivergeInterval();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Update diverge interval failed.";
    return status;
  }
  MS_LOG(INFO) << "start to collect data's distribution";
  status = DoInference(KL_BIN);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Collect data frequency failed.";
    return status;
  }
  MS_LOG(INFO) << "compute the best threshold";
  status = this->calibrator_->ComputeThreshold();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "compute threshold failed.";
    return status;
  }
  return RET_OK;
}

void FullQuantQuantizer::InitCpuConfig() {
  init_param_.activation_quant_data_type_ = kNumberTypeInt8;
  init_param_.activation_target_data_type_ = kNumberTypeInt8;
  init_param_.weight_data_type_ = kNumberTypeInt8;
  init_param_.activation_symmetric_ = false;
  init_param_.weight_channel_symmetric_ = true;
  init_param_.weight_layer_symmetric_ = false;
  support_int8_ops_ = {
    // Compute
    prim::kPrimConv2DFusion,
    prim::kPrimFullConnection,
    prim::kPrimMatMulFusion,
    // Memory
    prim::kPrimReshape,
    prim::kPrimTranspose,
    prim::kPrimShape,
    prim::kPrimUnsqueeze,
  };
  skip_check_dtype_ops_ = {prim::kPrimTupleGetItem, prim::kPrimShape};
  per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, prim::kPrimMatMulFusion,
                      prim::kPrimFullConnection, prim::kPrimLayerNormFusion};
  support_activation_ = {
    RELU, RELU6, HSWISH, SIGMOID, TANH,
    // LEAKY_RELU must be symmetric.
  };
}

void FullQuantQuantizer::InitKirinConfig() {
  // `kTypeUnknown` represents the original data type
  init_param_.activation_quant_data_type_ = kNumberTypeUInt8;
  init_param_.activation_target_data_type_ = kTypeUnknown;
  init_param_.weight_data_type_ = kNumberTypeInt8;
  init_param_.activation_symmetric_ = false;
  init_param_.weight_channel_symmetric_ = true;
  init_param_.weight_layer_symmetric_ = false;
  support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection};
  param_->fullQuantParam.bias_correction = false;
  per_channel_ops_ = {prim::kPrimConv2DFusion};
}

void FullQuantQuantizer::InitNvGpuConfig() {
  // `kTypeUnknown` represents the original data type
  init_param_.activation_target_data_type_ = kTypeUnknown;
  init_param_.activation_symmetric_ = true;
  init_param_.weight_data_type_ = kTypeUnknown;
  init_param_.weight_channel_symmetric_ = true;
  init_param_.weight_layer_symmetric_ = false;
  support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMul, prim::kPrimActivation,
                       prim::kPrimConv2dTransposeFusion};
  per_channel_ops_ = {};
  param_->fullQuantParam.bias_correction = false;
}

void FullQuantQuantizer::InitDSPConfig() {
  init_param_.activation_quant_data_type_ = kNumberTypeInt8;
  init_param_.activation_target_data_type_ = kNumberTypeInt8;
  init_param_.weight_data_type_ = kNumberTypeInt8;
  init_param_.activation_symmetric_ = false;
  init_param_.weight_channel_symmetric_ = true;
  init_param_.weight_layer_symmetric_ = false;
  support_int8_ops_ = {};
  skip_check_dtype_ops_ = {prim::kPrimTupleGetItem, prim::kPrimShape};
  per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion};
  support_activation_ = {RELU, RELU6, SIGMOID, TANH};
}

void FullQuantQuantizer::InitAscendConfig() {
  // `kTypeUnknown` represents the original data type
  init_param_.activation_quant_data_type_ = kNumberTypeInt8;
  init_param_.activation_target_data_type_ = kNumberTypeInt8;  // It will update to Int32 in acl pass
  init_param_.weight_data_type_ = kNumberTypeInt8;
  init_param_.activation_symmetric_ = false;
  init_param_.weight_channel_symmetric_ = true;
  init_param_.weight_layer_symmetric_ = true;
  support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMulFusion};
  per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMulFusion};
}

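// Derive the integer quantization ranges from the configured data types: int8 gives [-128, 127]
// (narrowed to [-127, 127] when symmetric), uint8 gives [0, 255]. Ranges are computed separately for
// activations, per-channel weights, and per-layer weights.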
void FullQuantQuantizer::InitQMinMax() {
  MS_ASSERT(init_param_.activation_quant_data_type_ == kNumberTypeInt8 ||
            init_param_.activation_quant_data_type_ == kNumberTypeUInt8);
  if (init_param_.activation_quant_data_type_ == kNumberTypeInt8) {
    init_param_.activation_q_min_ = QuantMin(this->init_param_.bit_num_, false,
                                             init_param_.activation_symmetric_);  // -128
    init_param_.activation_q_max_ = QuantMax(this->init_param_.bit_num_, false);  // 127
  } else if (init_param_.activation_quant_data_type_ == kNumberTypeUInt8) {
    init_param_.activation_q_min_ = QuantMin(this->init_param_.bit_num_, true, false);  // 0
    init_param_.activation_q_max_ = QuantMax(this->init_param_.bit_num_, true);         // 255
  }
  MS_ASSERT(init_param_.weight_data_type_ == kNumberTypeInt8 || init_param_.weight_data_type_ == kNumberTypeUInt8);
  if (init_param_.weight_data_type_ == kNumberTypeInt8) {
    init_param_.weight_channel_q_min_ = QuantMin(this->init_param_.bit_num_, false,
                                                 init_param_.weight_channel_symmetric_);  // -127
    init_param_.weight_channel_q_max_ = QuantMax(this->init_param_.bit_num_, false);      // 127
  } else if (init_param_.weight_data_type_ == kNumberTypeUInt8) {
    init_param_.weight_channel_q_min_ = QuantMin(this->init_param_.bit_num_, true, false);  // 0
    init_param_.weight_channel_q_max_ = QuantMax(this->init_param_.bit_num_, true);         // 255
  }
  if (init_param_.weight_data_type_ == kNumberTypeInt8) {
    init_param_.weight_layer_q_min_ = QuantMin(this->init_param_.bit_num_, false,
                                               init_param_.weight_layer_symmetric_);  // -128
    init_param_.weight_layer_q_max_ = QuantMax(this->init_param_.bit_num_, false);    // 127
  } else if (init_param_.weight_data_type_ == kNumberTypeUInt8) {
    init_param_.weight_layer_q_min_ = QuantMin(this->init_param_.bit_num_, true, false);  // 0
    init_param_.weight_layer_q_max_ = QuantMax(this->init_param_.bit_num_, true);         // 255
  }
}

int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) {
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    auto is_skip_op = quant_strategy_->IsSkipOp(cnode->fullname_with_scope());
    if (is_skip_op) {
      MS_LOG(INFO) << cnode->fullname_with_scope() << " is skipped for quant.";
      continue;
    }
    // Mark quantifiable nodes
    auto is_support_op = quant_strategy_->CanOpFullQuantized(func_graph->manager(), cnode, support_int8_ops_,
                                                             skip_check_dtype_ops_, support_activation_);
    if (is_support_op) {
      MS_LOG(INFO) << cnode->fullname_with_scope() << " is marked for quant.";
      auto ret = calibrator_->AddQuantizedOp(cnode);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << cnode->fullname_with_scope() << " add quantized op failed.";
        return ret;
      }
    }
  }
  return RET_OK;
}

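// Select the per-device quantization policy (supported ops, symmetry, data types), derive the integer
// ranges, create the calibrator and quant strategy, then mark the nodes to be quantized.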
int FullQuantQuantizer::InitDeviceConfig(const FuncGraphPtr &func_graph) {
  switch (param_->fullQuantParam.target_device) {
    case CPU:
      InitCpuConfig();
      break;
    case KIRIN:
      InitKirinConfig();
      break;
    case NVGPU:
      InitNvGpuConfig();
      break;
    case DSP:
      InitDSPConfig();
      break;
    case ASCEND:
      InitAscendConfig();
      break;
    default:
      MS_LOG(ERROR) << "Unsupported device " << param_->fullQuantParam.target_device;
      return RET_ERROR;
  }
  InitQMinMax();
  calibrator_ =
    std::make_shared<Calibrator>(this->init_param_.bit_num_, init_param_.activation_q_max_,
                                 init_param_.activation_q_min_, this->param_->fullQuantParam.activation_quant_method,
                                 this->param_->dataPreProcessParam, init_param_.activation_symmetric_);
  MSLITE_CHECK_PTR(calibrator_);
  if (param_->fullQuantParam.target_device == ASCEND) {
    quant_strategy_ = std::make_unique<QuantStrategy>(
      param_->commonQuantParam.min_quant_weight_size, param_->commonQuantParam.min_quant_weight_channel,
      param_->commonQuantParam.skip_quant_node, param_->fullQuantParam.target_device);
  } else {
    quant_strategy_ = std::make_unique<QuantStrategy>(0, 0, param_->commonQuantParam.skip_quant_node);
  }

  CHECK_NULL_RETURN(quant_strategy_);
  auto ret = MarkQuantNode(func_graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Mark quant node failed.";
    return ret;
  }
  return RET_OK;
}

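// Run the fp32 model over every calibration batch and collect data distributions through kernel callbacks:
// the before-callback records each node's inputs, the after-callback records its outputs. Node names are
// split on "_fusion" so the statistics key back to the original op names.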
int FullQuantQuantizer::DoInference(CollectType collect_type) {
  // get input tensor
  vector<mindspore::MSTensor> inputs = fp32_ms_model_->GetInputs();
  if (inputs.size() != calibrator_->GetInputNum()) {
    MS_LOG(ERROR) << "model's input tensor count: " << inputs.size()
                  << " != calibrator count: " << calibrator_->GetInputNum();
    return RET_ERROR;
  }

  for (size_t calib_index = 0; calib_index < calibrator_->GetBatchNum(); calib_index++) {
    MS_LOG(INFO) << "Do inference round: " << calib_index;
    // set multi-input data
    for (auto tensor : inputs) {
      int status = calibrator_->GenerateInputData(tensor.Name(), calib_index, &tensor);
      MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "generate input data from images failed!");
    }
    MSKernelCallBack beforeCallBack = [&](const std::vector<mindspore::MSTensor> &beforeInputs,
                                          const std::vector<mindspore::MSTensor> &beforeOutputs,
                                          const MSCallBackParam &callParam) -> bool {
      auto diverg_info_map = calibrator_->GetInputDivergInfo();
      // restore node name
      auto node_names = SplitStringToVector(callParam.node_name, "_fusion");
      if (node_names.empty()) {
        MS_LOG(WARNING) << "node_names is empty, callParam.node_name: " << callParam.node_name;
        return true;
      }
      auto ret = calibrator_->CollectDataDistribution(node_names.at(0), beforeInputs, diverg_info_map, collect_type);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "CollectDataDistribution failed.";
        return false;
      }
      return true;
    };
    MSKernelCallBack afterCallBack = [&](const std::vector<mindspore::MSTensor> &afterInputs,
                                         const std::vector<mindspore::MSTensor> &afterOutputs,
                                         const MSCallBackParam &callParam) -> bool {
      auto diverg_info_map = calibrator_->GetOutputDivergInfo();
      auto node_names = SplitStringToVector(callParam.node_name, "_fusion");
      if (node_names.empty()) {
        MS_LOG(WARNING) << "node_names is empty, callParam.node_name: " << callParam.node_name;
        return true;
      }
      auto ret = calibrator_->CollectDataDistribution(node_names.at(0), afterOutputs, diverg_info_map, collect_type);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "CollectDataDistribution failed.";
        return false;
      }
      return true;
    };
    auto outputs = fp32_ms_model_->GetOutputs();
    auto status = fp32_ms_model_->Predict(inputs, &outputs, beforeCallBack, afterCallBack);
    if (status != mindspore::kSuccess) {
      MS_LOG(ERROR) << "run model failed!";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

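// Entry point of full quantization: build an fp32 model from the func graph, run min/max calibration,
// optionally refine thresholds with KL, write quant params into the graph, insert quant/dequant cast
// nodes for int8 activation targets, and finally apply bias correction if enabled.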
int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  MS_ASSERT(func_graph != nullptr);
  MS_LOG(INFO) << "start to parse config file";
  if (param_->dataPreProcessParam.calibrate_path.empty()) {
    MS_LOG(ERROR) << "calibrate path must be provided. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
    return RET_INPUT_PARAM_INVALID;
  }

  auto status = InitDeviceConfig(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Init device config failed!";
    return status;
  }

  // anf -> fb: build an inference model from the func graph for calibration
  MS_LOG(INFO) << "start create session";
  fp32_ms_model_ = std::make_shared<mindspore::Model>();
  if (fp32_ms_model_ == nullptr) {
    MS_LOG(ERROR) << "New model failed.";
    return RET_ERROR;
  }
  size_t size = 0;
  auto ret = BuildModelByFuncGraph(fp32_ms_model_, func_graph, param_, &size);
  if (ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "Build model failed.";
    return RET_ERROR;
  }
  MS_LOG(INFO) << "start to update divergence's max value";
  status = DoInference(MIN_MAX);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Do inference failed.";
    return status;
  }

  if (param_->fullQuantParam.activation_quant_method == KL) {
    status = QuantWithKL();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Quant with KL failed.";
      return status;
    }
  }

  MS_LOG(INFO) << "start to generate quant param and quantize tensor's data";
  status = QuantNode(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Quant node failed.";
    return status;
  }

  if (init_param_.activation_target_data_type_ == kNumberTypeInt8 ||
      init_param_.activation_target_data_type_ == kNumberTypeUInt8) {  // ASCEND bias correction also needs it.
    // add quant_cast
    for (auto &cnode : func_graph->GetOrderedCnodes()) {
      quant::QuantType curr_quant_type;
      if (GetQuantTypeNew(cnode, &curr_quant_type) != RET_OK) {
        MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
        return RET_ERROR;
      }
      quant::InsertQuantNodeManager insert_node_manager;
      status = insert_node_manager.InsertCastNodeForFullQuant(func_graph, cnode, kNumberTypeFloat32, curr_quant_type);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertCastNodeForFullQuant failed, cnode name: " << cnode->fullname_with_scope();
        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
        return status;
      }
    }
  }

  if (this->param_->fullQuantParam.bias_correction) {
    MS_LOG(INFO) << "do bias correction";
    BiasCorrectionStrategy strategy(param_, calibrator_, quant_strategy_, fp32_ms_model_, init_param_.activation_q_min_,
                                    init_param_.activation_q_max_);
    status = strategy.DoBiasCorrection(func_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do bias correction failed.";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return RET_ERROR;
    }
  }
  return RET_OK;
}
}  // namespace mindspore::lite::quant