1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18
19 #include "tools/converter/quantizer/full_quant_quantizer.h"
20 #include <dirent.h>
21 #include <memory>
22 #include <unordered_map>
23 #include <string>
24 #include <vector>
25 #include "mindspore/core/ops/sequence_ops.h"
26 #include "mindspore/core/ops/nn_ops.h"
27 #include "mindspore/core/ops/math_ops.h"
28 #include "mindspore/core/ops/lite_ops.h"
29 #include "mindspore/core/ops/array_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "ops/tuple_get_item.h"
32 #include "src/tensor.h"
33 #include "tools/converter/quantizer/insert_quant_node_manager.h"
34 #include "tools/converter/quantizer/quantize_util.h"
35 #include "tools/optimizer/common/gllo_utils.h"
36 #include "src/common/log_adapter.h"
37 #include "tools/common/tensor_util.h"
38 #include "src/common/utils.h"
39 #include "tools/common/node_util.h"
40 #include "nnacl/op_base.h"
41 #include "src/common/log_util.h"
42 #include "tools/converter/quantizer/bias_correction_strategy.h"
43
44 using std::string;
45 using std::vector;
46
47 namespace mindspore::lite::quant {
~FullQuantQuantizer()48 FullQuantQuantizer::~FullQuantQuantizer() {}
49
GetQuantParam(const AnfNodePtr & input_node,const std::unique_ptr<DataDistribution> & info) const50 std::vector<schema::QuantParamT> FullQuantQuantizer::GetQuantParam(
51 const AnfNodePtr &input_node, const std::unique_ptr<DataDistribution> &info) const {
52 std::vector<schema::QuantParamT> quant_params;
53 schema::QuantParamT quant_param;
54 TypeId type_id = kTypeUnknown;
55 if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
56 MS_LOG(ERROR) << "Get data type failed.";
57 return quant_params;
58 }
59 if (type_id == kNumberTypeFloat32 && info != nullptr) {
60 quant_param.scale = info->GetScale();
61 quant_param.zeroPoint = info->GetZeroPoint();
62 quant_param.max = info->GetEncodeMax();
63 quant_param.min = info->GetEncodeMin();
64 quant_param.numBits = init_param_.bit_num_;
65 quant_param.narrowRange = true;
66 quant_param.inited = true;
67 quant_param.roundType = 1;
68 quant_param.multiplier = 1;
69 quant_params.push_back(quant_param);
70 }
71 return quant_params;
72 }
73
QuantWeight(const CNodePtr & cnode,const PrimitivePtr & primitive,const AnfNodePtr & weight,int input_index,const tensor::TensorPtr & tensor_info,bool per_channel)74 int FullQuantQuantizer::QuantWeight(const CNodePtr &cnode, const PrimitivePtr &primitive, const AnfNodePtr &weight,
75 int input_index, const tensor::TensorPtr &tensor_info, bool per_channel) {
76 int preferred_dim = GetPreferredDim(cnode, input_index - 1, ConvertShapeVectorToInt32(tensor_info->shape()));
77 auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
78 auto weight_q_min = per_channel ? init_param_.weight_channel_q_min_ : init_param_.weight_layer_q_min_;
79 auto weight_q_max = per_channel ? init_param_.weight_channel_q_max_ : init_param_.weight_layer_q_max_;
80 auto symmetric = per_channel ? init_param_.weight_channel_symmetric_ : init_param_.weight_layer_symmetric_;
81 return fixed_bit_quant_.QuantFilter(weight, tensor_info, primitive, quant::QUANT_ALL, weight_q_max, weight_q_min,
82 init_param_.bit_num_, weight_quant_type, kNumberTypeInt8, preferred_dim,
83 symmetric);
84 }
85
DoParameterWeightQuant(const CNodePtr & cnode,const ParameterPtr & weight,const PrimitivePtr & primitive,int input_index,bool per_channel)86 int FullQuantQuantizer::DoParameterWeightQuant(const CNodePtr &cnode, const ParameterPtr &weight,
87 const PrimitivePtr &primitive, int input_index, bool per_channel) {
88 CHECK_NULL_RETURN(cnode);
89 CHECK_NULL_RETURN(weight);
90 CHECK_NULL_RETURN(primitive);
91 auto tensor_info = weight->default_param()->cast<tensor::TensorPtr>();
92 if (tensor_info == nullptr) {
93 MS_LOG(ERROR) << weight->fullname_with_scope() << " can't get value";
94 return RET_NULL_PTR;
95 }
96 return QuantWeight(cnode, primitive, weight, input_index, tensor_info, per_channel);
97 }
98
DoValueNodeWeightQuant(const CNodePtr & cnode,const ValueNodePtr & weight,const PrimitivePtr & primitive,int input_index,bool per_channel)99 int FullQuantQuantizer::DoValueNodeWeightQuant(const CNodePtr &cnode, const ValueNodePtr &weight,
100 const PrimitivePtr &primitive, int input_index, bool per_channel) {
101 CHECK_NULL_RETURN(weight);
102 CHECK_NULL_RETURN(primitive);
103 auto tensor_info = weight->value()->cast<tensor::TensorPtr>();
104 if (tensor_info == nullptr) {
105 MS_LOG(ERROR) << weight->fullname_with_scope() << " can't get value";
106 return RET_NULL_PTR;
107 }
108 return QuantWeight(cnode, primitive, weight, input_index, tensor_info, per_channel);
109 }
110
IsSupportWeightQuant(const AnfNodePtr & input_node)111 int FullQuantQuantizer::IsSupportWeightQuant(const AnfNodePtr &input_node) {
112 TypeId type_id = kTypeUnknown;
113 if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
114 MS_LOG(ERROR) << "Get data type failed.";
115 return RET_ERROR;
116 }
117 // support for share weight.
118 if (type_id == kNumberTypeInt8) {
119 auto iter = weight_quant_params_bak_.find(input_node->fullname_with_scope());
120 if (iter == weight_quant_params_bak_.end()) {
121 return RET_ERROR;
122 } else {
123 return RET_NO_CHANGE;
124 }
125 }
126 // Only data the data type is fp32 can be quant.
127 if (type_id != kNumberTypeFloat32) {
128 return RET_NO_CHANGE;
129 }
130 return RET_OK;
131 }
132
DoParameterNodeQuant(const CNodePtr & cnode,const ParameterPtr & input_node,size_t input_index)133 int FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const ParameterPtr &input_node,
134 size_t input_index) {
135 auto ret = IsSupportWeightQuant(input_node);
136 if (ret != RET_OK) {
137 return ret;
138 }
139 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
140 CHECK_NULL_RETURN(primitive);
141 auto op_name = cnode->fullname_with_scope();
142 if (input_index == THIRD_INPUT + kPrimOffset && CheckNodeInSet(cnode, kHasBiasOperator)) {
143 auto weight_parameter = cnode->input(SECOND_INPUT + kPrimOffset)->cast<ParameterPtr>();
144 auto active_quant_params = quant::GetInputNodeQuantParam(cnode, FIRST_INPUT + kPrimOffset);
145 ret = fixed_bit_quant_.QuantBias(weight_parameter, input_node, active_quant_params);
146 if (ret != RET_OK) {
147 MS_LOG(ERROR) << op_name << " Do bias quant failed.";
148 return ret;
149 }
150 } else if (param_->fullQuantParam.per_channel && CheckNodeInSet(cnode, per_channel_ops_)) {
151 ret = DoParameterWeightQuant(cnode, input_node, primitive, input_index, true);
152 if (ret != RET_OK) {
153 MS_LOG(ERROR) << op_name << " Do bias quant failed.";
154 return ret;
155 }
156 } else {
157 ret = DoParameterWeightQuant(cnode, input_node, primitive, input_index, false);
158 if (ret != RET_OK) {
159 MS_LOG(ERROR) << op_name << " Do bias quant failed.";
160 return ret;
161 }
162 }
163 // support shared weight
164 auto tensor_info = input_node->default_param()->cast<tensor::TensorPtr>();
165 if (tensor_info->quant_params().empty()) {
166 return RET_NO_CHANGE;
167 }
168 auto quant_params = quant::ConvertQuantizationParamToQuantParamT(tensor_info->quant_params().front());
169 weight_quant_params_bak_[input_node->fullname_with_scope()] = quant_params;
170 return RET_OK;
171 }
172
DoValueNodeQuant(const CNodePtr & cnode,const ValueNodePtr & input_node,size_t input_index)173 int FullQuantQuantizer::DoValueNodeQuant(const CNodePtr &cnode, const ValueNodePtr &input_node, size_t input_index) {
174 auto ret = IsSupportWeightQuant(input_node);
175 if (ret != RET_OK) {
176 return ret;
177 }
178 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
179 CHECK_NULL_RETURN(primitive);
180 auto op_name = cnode->fullname_with_scope();
181 ret = DoValueNodeWeightQuant(cnode, input_node, primitive, input_index, false);
182 if (ret != RET_OK) {
183 MS_LOG(ERROR) << op_name << " Do value node weight quant failed.";
184 return ret;
185 }
186 return RET_OK;
187 }
188
QuantNodeGraphInput(const PrimitivePtr & primitive,const AnfNodePtr & input_node,const std::unique_ptr<DataDistribution> & info)189 int FullQuantQuantizer::QuantNodeGraphInput(const PrimitivePtr &primitive, const AnfNodePtr &input_node,
190 const std::unique_ptr<DataDistribution> &info) {
191 TypeId type_id = kTypeUnknown;
192 if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
193 MS_LOG(ERROR) << "Get data type failed.";
194 return RET_ERROR;
195 }
196 if (type_id == kNumberTypeFloat32 && info != nullptr) {
197 schema::QuantParamT quant_param;
198 quant_param.scale = info->GetScale();
199 quant_param.zeroPoint = info->GetZeroPoint();
200 quant_param.max = info->GetEncodeMax();
201 quant_param.min = info->GetEncodeMin();
202 quant_param.numBits = static_cast<int32_t>(init_param_.bit_num_);
203 quant_param.narrowRange = true;
204 quant_param.inited = true;
205 quant_param.roundType = 1;
206 quant_param.multiplier = 1;
207 auto quantization_param = quant::ConvertQuantParamTToQuantizationParam({quant_param});
208 primitive->AddAttr(quant::kGraphInputQuantParam, quantization_param);
209 }
210 return RET_OK;
211 }
212
QuantNodeCNode(const CNodePtr & cnode,const AnfNodePtr & input_node,const std::unique_ptr<DataDistribution> & info)213 int FullQuantQuantizer::QuantNodeCNode(const CNodePtr &cnode, const AnfNodePtr &input_node,
214 const std::unique_ptr<DataDistribution> &info) {
215 auto input_cnode = input_node->cast<mindspore::CNodePtr>();
216 MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_NULL_PTR, "input_cnode is nullptr.");
217 auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
218 CHECK_NULL_RETURN(input_cnode_primitive);
219 if (input_cnode_primitive->HasAttr(quant::kQuantParam)) {
220 MS_LOG(INFO) << input_node->fullname_with_scope() << " quant param already exist.";
221 return RET_NO_CHANGE;
222 }
223 auto quant_params = GetQuantParam(input_node, info);
224 if (quant_params.empty()) {
225 MS_LOG(INFO) << input_node->fullname_with_scope() << " quant param not exist.";
226 return RET_NO_CHANGE;
227 }
228 auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(quant_params);
229 if (quantization_ptr != nullptr) {
230 std::vector<ValuePtr> quantization_list = {quantization_ptr};
231 input_cnode_primitive->AddAttr(quant::kQuantParam, std::make_shared<ValueList>(quantization_list));
232 }
233 return RET_OK;
234 }
235
QuantValueNode(const CNodePtr & cnode,const AnfNodePtr & input_node,size_t i)236 int FullQuantQuantizer::QuantValueNode(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t i) {
237 if (init_param_.weight_data_type_ == kTypeUnknown) {
238 MS_LOG(INFO) << "weight parameters do not need to be quantified.";
239 return RET_NO_CHANGE;
240 }
241 auto value_node = input_node->cast<ValueNodePtr>();
242 auto ret = DoValueNodeQuant(cnode, value_node, i);
243 if (ret == RET_NO_CHANGE) {
244 return RET_NO_CHANGE;
245 } else if (ret != RET_OK) {
246 MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do value node quant failed.";
247 return ret;
248 }
249 // support shared weight
250 auto tensor_info = value_node->value()->cast<tensor::TensorPtr>();
251 if (tensor_info->quant_params().empty()) {
252 return RET_NO_CHANGE;
253 }
254 auto quant_params = quant::ConvertQuantizationParamToQuantParamT(tensor_info->quant_params().front());
255 weight_quant_params_bak_[input_node->fullname_with_scope()] = quant_params;
256 return RET_OK;
257 }
258
// Quantize every input of a "simple" op: dispatch each input by its node kind
// (graph input / cnode / parameter / value node) to the matching quant routine.
// Inputs whose quantization is inapplicable (RET_NO_CHANGE) are skipped; any
// hard failure aborts the whole node.
int FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  CHECK_NULL_RETURN(primitive);
  auto op_name = cnode->fullname_with_scope();
  MS_ASSERT(cnode->size() - 1 <= (*inputs_diverg_info)[op_name].size());
  int ret;
  // Input 0 is the primitive itself; real operands start at index 1.
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    CHECK_NULL_RETURN(input_node);
    bool is_graph_input = IsGraphInput(input_node);
    if (is_graph_input) {
      // do input quant
      auto &info = (*inputs_diverg_info)[op_name][i - 1];
      if (info == nullptr) {
        MS_LOG(INFO) << input_node->fullname_with_scope() << " quant info not exist.";
        continue;
      }
      if (QuantNodeGraphInput(primitive, input_node, info) != RET_OK) {
        MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do graph input node quant failed.";
        return RET_ERROR;
      }
    } else if (input_node->isa<mindspore::CNode>()) {
      // Activation produced by another op: stamp its quant params onto the producer.
      auto &info = (*inputs_diverg_info)[op_name][i - 1];
      ret = QuantNodeCNode(cnode, input_node, info);
      if (ret != RET_NO_CHANGE && ret != RET_OK) {
        MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do cnode quant failed.";
        return RET_ERROR;
      }
    } else if (input_node->isa<mindspore::Parameter>()) {
      // Weight or bias parameter; skipped entirely when weight quant is disabled.
      if (init_param_.weight_data_type_ == kTypeUnknown) {
        MS_LOG(INFO) << "weight parameters do not need to be quantified.";
        continue;
      }
      auto parameter_node = input_node->cast<ParameterPtr>();
      ret = DoParameterNodeQuant(cnode, parameter_node, i);
      if (ret == RET_NO_CHANGE) {
        continue;
      } else if (ret != RET_OK) {
        MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do parameter node quant failed.";
        return ret;
      }
    } else if (input_node->isa<mindspore::ValueNode>()) {
      // Constant tensor baked into the graph as a value node.
      ret = QuantValueNode(cnode, input_node, i);
      if (ret == RET_NO_CHANGE) {
        continue;
      } else if (ret != RET_OK) {
        MS_LOG(ERROR) << input_node->fullname_with_scope() << " Do Value node quant failed.";
        return ret;
      }
    } else {
      MS_LOG(ERROR) << input_node->fullname_with_scope() << ":" << input_node->type_name() << " is not support type";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
317
// Walk all cnodes of the graph and generate quant params: marks un-calibrated
// ops QUANT_NONE, forwards quant params through TupleGetItem, runs
// QuantNodeSimpleOp for everything else, then attaches per-output quant params
// (one entry per output) and marks the op QUANT_ALL.
int FullQuantQuantizer::QuantNode(const FuncGraphPtr &func_graph) {
  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
  auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();

  auto cnodes = func_graph->GetOrderedCnodes();
  for (const auto &cnode : cnodes) {
    // QuantDTypeCast nodes are quantization plumbing themselves; never re-quantize.
    if (opt::CheckPrimitiveType(cnode, prim::kPrimQuantDTypeCast)) {
      continue;
    }
    auto op_name = cnode->fullname_with_scope();
    MS_LOG(INFO) << "Quant node op name: " << op_name;
    auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "primitive is nullptr";
      return RET_ERROR;
    }
    // No calibration data collected for this op -> leave it float.
    if (inputs_diverg_info->find(op_name) == inputs_diverg_info->end()) {
      MS_LOG(INFO) << op_name << " can not do quant";
      primitive->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_NONE)));
      continue;
    }

    auto op_type = primitive->name();
    if (op_type == mindspore::ops::kNameTupleGetItem) {
      // TupleGetItem computes nothing: copy the quant params of the selected
      // tuple element (index taken from the constant second operand).
      constexpr int tuple_get_item_input_size = 3;
      MS_CHECK_TRUE_MSG(cnode->size() == tuple_get_item_input_size, RET_ERROR, "cnode->size() != 3");
      auto index_node = cnode->input(THIRD_INPUT);
      auto index_value_node = index_node->cast<mindspore::ValueNodePtr>();
      if (index_value_node == nullptr) {
        MS_LOG(WARNING) << "index value node is null";
        continue;
      }
      size_t index = static_cast<size_t>(opt::CastToInt(index_value_node->value()).front());
      auto input_node_quant_params = quant::GetInputNodeQuantParam(cnode, FIRST_INPUT + kPrimOffset, index);
      std::vector<ValuePtr> quantization_list;
      auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(input_node_quant_params);
      if (quantization_ptr != nullptr) {
        quantization_list.push_back(quantization_ptr);
        primitive->AddAttr(quant::kQuantParam, std::make_shared<ValueList>(quantization_list));
        primitive->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_ALL)));
      } else {
        MS_LOG(WARNING) << cnode->fullname_with_scope() << "this TupleGetItem node's input_node_quant_params is empty.";
      }
      continue;
    } else {  // do simple op quant
      auto status = QuantNodeSimpleOp(cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "simple op quant failed.";
        return status;
      }
    }
    // do output quant, there may multi-output
    auto &infos = (*outputs_diverg_info)[op_name];
    std::vector<ValuePtr> quantization_list;
    for (size_t index = 0; index < infos.size(); index++) {
      auto &info = infos.at(index);
      auto quant_params = GetQuantParam(cnode, info);
      auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(quant_params);
      if (quantization_ptr != nullptr) {
        quantization_list.push_back(quantization_ptr);
      }
    }
    primitive->AddAttr(quant::kQuantParam, std::make_shared<ValueList>(quantization_list));
    primitive->AddAttr(quant::kQuantType, MakeValue(static_cast<int>(quant::QUANT_ALL)));
  }
  return RET_OK;
}
385
UpdateDivergeInterval()386 int FullQuantQuantizer::UpdateDivergeInterval() {
387 auto ret = this->calibrator_->UpdateDivergInterval();
388 if (ret != RET_OK) {
389 MS_LOG(ERROR) << "Update input diverge interval failed.";
390 return ret;
391 }
392 return RET_OK;
393 }
394
QuantWithKL()395 int FullQuantQuantizer::QuantWithKL() {
396 MS_LOG(INFO) << "start to update divergence's interval";
397 auto status = UpdateDivergeInterval();
398 if (status != RET_OK) {
399 MS_LOG(ERROR) << "Update diverge interval failed.";
400 return status;
401 }
402 MS_LOG(INFO) << "start to collect data's distribution";
403 status = DoInference(KL_BIN);
404 if (status != RET_OK) {
405 MS_LOG(ERROR) << "Collect data frequency failed.";
406 return status;
407 }
408 MS_LOG(INFO) << "compute the best threshold";
409 status = this->calibrator_->ComputeThreshold();
410 if (status != RET_OK) {
411 MS_LOG(ERROR) << "compute threshold failed.";
412 return status;
413 }
414 return RET_OK;
415 }
416
InitCpuConfig()417 void FullQuantQuantizer::InitCpuConfig() {
418 init_param_.activation_quant_data_type_ = kNumberTypeInt8;
419 init_param_.activation_target_data_type_ = kNumberTypeInt8;
420 init_param_.weight_data_type_ = kNumberTypeInt8;
421 init_param_.activation_symmetric_ = false;
422 init_param_.weight_channel_symmetric_ = true;
423 init_param_.weight_layer_symmetric_ = false;
424 support_int8_ops_ = {
425 // Compute
426 prim::kPrimConv2DFusion,
427 prim::kPrimFullConnection,
428 prim::kPrimMatMulFusion,
429 // Memory
430 prim::kPrimReshape,
431 prim::kPrimTranspose,
432 prim::kPrimShape,
433 prim::kPrimUnsqueeze,
434 };
435 skip_check_dtype_ops_ = {prim::kPrimTupleGetItem, prim::kPrimShape};
436 per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion, prim::kPrimMatMulFusion,
437 prim::kPrimFullConnection, prim::kPrimLayerNormFusion};
438 support_activation_ = {
439 RELU, RELU6, HSWISH, SIGMOID, TANH,
440 // LEAKY_RELU must be symmetric.
441 };
442 }
443
InitKirinConfig()444 void FullQuantQuantizer::InitKirinConfig() {
445 // `kTypeUnknown` represents the original data type
446 init_param_.activation_quant_data_type_ = kNumberTypeUInt8;
447 init_param_.activation_target_data_type_ = kTypeUnknown;
448 init_param_.weight_data_type_ = kNumberTypeInt8;
449 init_param_.activation_symmetric_ = false;
450 init_param_.weight_channel_symmetric_ = true;
451 init_param_.weight_layer_symmetric_ = false;
452 support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimFullConnection};
453 param_->fullQuantParam.bias_correction = false;
454 per_channel_ops_ = {prim::kPrimConv2DFusion};
455 }
456
InitNvGpuConfig()457 void FullQuantQuantizer::InitNvGpuConfig() {
458 // `kTypeUnknown` represents the original data type
459 init_param_.activation_target_data_type_ = kTypeUnknown;
460 init_param_.activation_symmetric_ = true;
461 init_param_.weight_data_type_ = kTypeUnknown;
462 init_param_.weight_channel_symmetric_ = true;
463 init_param_.weight_layer_symmetric_ = false;
464 support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMul, prim::kPrimActivation,
465 prim::kPrimConv2dTransposeFusion};
466 per_channel_ops_ = {};
467 param_->fullQuantParam.bias_correction = false;
468 }
469
InitDSPConfig()470 void FullQuantQuantizer::InitDSPConfig() {
471 init_param_.activation_quant_data_type_ = kNumberTypeInt8;
472 init_param_.activation_target_data_type_ = kNumberTypeInt8;
473 init_param_.weight_data_type_ = kNumberTypeInt8;
474 init_param_.activation_symmetric_ = false;
475 init_param_.weight_channel_symmetric_ = true;
476 init_param_.weight_layer_symmetric_ = false;
477 support_int8_ops_ = {};
478 skip_check_dtype_ops_ = {prim::kPrimTupleGetItem, prim::kPrimShape};
479 per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion};
480 support_activation_ = {RELU, RELU6, SIGMOID, TANH};
481 }
482
InitAscendConfig()483 void FullQuantQuantizer::InitAscendConfig() {
484 // `kTypeUnknown` represents the original data type
485 init_param_.activation_quant_data_type_ = kNumberTypeInt8;
486 init_param_.activation_target_data_type_ = kNumberTypeInt8; // It will update to Int32 in acl pass
487 init_param_.weight_data_type_ = kNumberTypeInt8;
488 init_param_.activation_symmetric_ = false;
489 init_param_.weight_channel_symmetric_ = true;
490 init_param_.weight_layer_symmetric_ = true;
491 support_int8_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMulFusion};
492 per_channel_ops_ = {prim::kPrimConv2DFusion, prim::kPrimMatMulFusion};
493 }
494
InitQMinMax()495 void FullQuantQuantizer::InitQMinMax() {
496 MS_ASSERT(init_param_.activation_quant_data_type_ == kNumberTypeInt8 ||
497 init_param_.activation_quant_data_type_ == kNumberTypeUInt8);
498 if (init_param_.activation_quant_data_type_ == kNumberTypeInt8) {
499 init_param_.activation_q_min_ = QuantMin(this->init_param_.bit_num_, false,
500 init_param_.activation_symmetric_); // -128
501 init_param_.activation_q_max_ = QuantMax(this->init_param_.bit_num_, false); // 127
502 } else if (init_param_.activation_quant_data_type_ == kNumberTypeUInt8) {
503 init_param_.activation_q_min_ = QuantMin(this->init_param_.bit_num_, true, false); // 0
504 init_param_.activation_q_max_ = QuantMax(this->init_param_.bit_num_, true); // 255
505 }
506 MS_ASSERT(init_param_.weight_data_type_ == kNumberTypeInt8 || init_param_.weight_data_type_ == kNumberTypeUInt8);
507 if (init_param_.weight_data_type_ == kNumberTypeInt8) {
508 init_param_.weight_channel_q_min_ = QuantMin(this->init_param_.bit_num_, false,
509 init_param_.weight_channel_symmetric_); // -127
510 init_param_.weight_channel_q_max_ = QuantMax(this->init_param_.bit_num_, false); // 127
511 } else if (init_param_.activation_quant_data_type_ == kNumberTypeUInt8) {
512 init_param_.weight_channel_q_min_ = QuantMin(this->init_param_.bit_num_, true, false); // 0
513 init_param_.weight_channel_q_max_ = QuantMax(this->init_param_.bit_num_, true); // 255
514 }
515 if (init_param_.weight_data_type_ == kNumberTypeInt8) {
516 init_param_.weight_layer_q_min_ = QuantMin(this->init_param_.bit_num_, false,
517 init_param_.weight_layer_symmetric_); // -128
518 init_param_.weight_layer_q_max_ = QuantMax(this->init_param_.bit_num_, false); // 127
519 } else if (init_param_.activation_quant_data_type_ == kNumberTypeUInt8) {
520 init_param_.weight_layer_q_min_ = QuantMin(this->init_param_.bit_num_, true, false); // 0
521 init_param_.weight_layer_q_max_ = QuantMax(this->init_param_.bit_num_, true); // 255
522 }
523 }
524
MarkQuantNode(const FuncGraphPtr & func_graph)525 int FullQuantQuantizer::MarkQuantNode(const FuncGraphPtr &func_graph) {
526 auto cnodes = func_graph->GetOrderedCnodes();
527 for (auto &cnode : cnodes) {
528 auto is_skip_op = quant_strategy_->IsSkipOp(cnode->fullname_with_scope());
529 if (is_skip_op) {
530 MS_LOG(INFO) << cnode->fullname_with_scope() << " is skip quant.";
531 continue;
532 }
533 // Mark quantifiable nodes
534 auto is_support_op = quant_strategy_->CanOpFullQuantized(func_graph->manager(), cnode, support_int8_ops_,
535 skip_check_dtype_ops_, support_activation_);
536 if (is_support_op) {
537 MS_LOG(INFO) << cnode->fullname_with_scope() << " mark quant.";
538 auto ret = calibrator_->AddQuantizedOp(cnode);
539 if (ret != RET_OK) {
540 MS_LOG(ERROR) << cnode->fullname_with_scope() << " add quantized op failed.";
541 return ret;
542 }
543 }
544 }
545 return RET_OK;
546 }
547
// Initialize everything that depends on the target device: per-device quant
// defaults, the derived q-min/q-max ranges, the calibrator, and the quant
// strategy; finally mark the graph nodes that will be quantized.
int FullQuantQuantizer::InitDeviceConfig(const FuncGraphPtr &func_graph) {
  switch (param_->fullQuantParam.target_device) {
    case CPU:
      InitCpuConfig();
      break;
    case KIRIN:
      InitKirinConfig();
      break;
    case NVGPU:
      InitNvGpuConfig();
      break;
    case DSP:
      InitDSPConfig();
      break;
    case ASCEND:
      InitAscendConfig();
      break;
    default:
      MS_LOG(ERROR) << " Unsupported device " << param_->fullQuantParam.target_device;
      return RET_ERROR;
  }
  // Ranges must be computed after the device config fixed the data types.
  InitQMinMax();
  calibrator_ =
    std::make_shared<Calibrator>(this->init_param_.bit_num_, init_param_.activation_q_max_,
                                 init_param_.activation_q_min_, this->param_->fullQuantParam.activation_quant_method,
                                 this->param_->dataPreProcessParam, init_param_.activation_symmetric_);
  MSLITE_CHECK_PTR(calibrator_);
  // Ascend keeps the user's size/channel thresholds; other devices quantize
  // regardless of weight size (thresholds 0, 0).
  if (param_->fullQuantParam.target_device == ASCEND) {
    quant_strategy_ = std::make_unique<QuantStrategy>(
      param_->commonQuantParam.min_quant_weight_size, param_->commonQuantParam.min_quant_weight_channel,
      param_->commonQuantParam.skip_quant_node, param_->fullQuantParam.target_device);
  } else {
    quant_strategy_ = std::make_unique<QuantStrategy>(0, 0, param_->commonQuantParam.skip_quant_node);
  }

  CHECK_NULL_RETURN(quant_strategy_);
  auto ret = MarkQuantNode(func_graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Mark quant node failed.";
    return ret;
  }
  return RET_OK;
}
591
DoInference(CollectType collect_type)592 int FullQuantQuantizer::DoInference(CollectType collect_type) {
593 // get input tensor
594 vector<mindspore::MSTensor> inputs = fp32_ms_model_->GetInputs();
595 if (inputs.size() != calibrator_->GetInputNum()) {
596 MS_LOG(ERROR) << "model's input tensor count: " << inputs.size() << " != "
597 << " calibrator count:" << calibrator_->GetInputNum();
598 return RET_ERROR;
599 }
600
601 for (size_t calib_index = 0; calib_index < calibrator_->GetBatchNum(); calib_index++) {
602 MS_LOG(INFO) << "Do inference round: " << calib_index;
603 // set multi-input data
604 for (auto tensor : inputs) {
605 int status = calibrator_->GenerateInputData(tensor.Name(), calib_index, &tensor);
606 MS_CHECK_TRUE_MSG(status == RET_OK, RET_ERROR, "generate input data from images failed!");
607 }
608 MSKernelCallBack beforeCallBack = [&](const std::vector<mindspore::MSTensor> &beforeInputs,
609 const std::vector<mindspore::MSTensor> &beforeOutputs,
610 const MSCallBackParam &callParam) -> bool {
611 auto diverg_info_map = calibrator_->GetInputDivergInfo();
612 // restore node name
613 auto node_names = SplitStringToVector(callParam.node_name, "_fusion");
614 MS_CHECK_TRUE_MSG(!node_names.empty(), false, "node_names is empty.");
615 if (node_names.empty()) {
616 MS_LOG(WARNING) << "node_names is empty, callParam.node_name: " << callParam.node_name;
617 return true;
618 }
619 auto ret = calibrator_->CollectDataDistribution(node_names.at(0), beforeInputs, diverg_info_map, collect_type);
620 if (ret != RET_OK) {
621 MS_LOG(ERROR) << "CollectDataDistribution failed.";
622 return false;
623 }
624 return true;
625 };
626 // func
627 MSKernelCallBack afterCallBack = [&](const std::vector<mindspore::MSTensor> &afterInputs,
628 const std::vector<mindspore::MSTensor> &afterOutputs,
629 const MSCallBackParam &callParam) -> bool {
630 auto diverg_info_map = calibrator_->GetOutputDivergInfo();
631 auto node_names = SplitStringToVector(callParam.node_name, "_fusion");
632 if (node_names.empty()) {
633 MS_LOG(WARNING) << "node_names is empty, callParam.node_name: " << callParam.node_name;
634 return true;
635 }
636 auto ret = calibrator_->CollectDataDistribution(node_names.at(0), afterOutputs, diverg_info_map, collect_type);
637 if (ret != RET_OK) {
638 MS_LOG(ERROR) << "CollectDataDistribution failed.";
639 return false;
640 }
641 return true;
642 };
643 auto outputs = fp32_ms_model_->GetOutputs();
644 auto status = fp32_ms_model_->Predict(inputs, &outputs, beforeCallBack, afterCallBack);
645 if (status != mindspore::kSuccess) {
646 MS_LOG(ERROR) << "run model failed!";
647 return RET_ERROR;
648 }
649 }
650 return RET_OK;
651 }
652
// Full post-training quantization pipeline: validate config, init device
// settings, build & run the fp32 model on calibration data to collect
// statistics (optionally refined by KL), write quant params into the graph,
// insert dtype-cast nodes, and optionally apply bias correction.
int FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  MS_ASSERT(func_graph != nullptr);
  MS_LOG(INFO) << "start to parse config file";
  if (param_->dataPreProcessParam.calibrate_path.empty()) {
    MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
    return RET_INPUT_PARAM_INVALID;
  }

  auto status = InitDeviceConfig(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "do pre process failed!";
    return status;
  }

  // anf -- fb
  // Build an executable fp32 model from the graph for calibration inference.
  MS_LOG(INFO) << "start create session";
  fp32_ms_model_ = std::make_shared<mindspore::Model>();
  if (fp32_ms_model_ == nullptr) {
    MS_LOG(ERROR) << "New model failed.";
    return RET_ERROR;
  }
  size_t size = 0;
  auto ret = BuildModelByFuncGraph(fp32_ms_model_, func_graph, param_, &size);
  if (ret != mindspore::kSuccess) {
    MS_LOG(ERROR) << "Build model failed.";
    return RET_ERROR;
  }
  // First pass always collects min/max statistics.
  MS_LOG(INFO) << "start to update divergence's max value";
  status = DoInference(MIN_MAX);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Do inference failed.";
    return status;
  }

  // KL method needs an extra histogram pass plus threshold search.
  if (param_->fullQuantParam.activation_quant_method == KL) {
    status = QuantWithKL();
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Quant with KL failed.";
      return status;
    }
  }

  MS_LOG(INFO) << "start to generate quant param and quantize tensor's data";
  status = QuantNode(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Quant node failed.";
    return status;
  }

  if (init_param_.activation_target_data_type_ == kNumberTypeInt8 ||
      init_param_.activation_target_data_type_ == kNumberTypeUInt8) {  // ASCEND bias correction also need it.
    // add quant_cast
    for (auto &cnode : func_graph->GetOrderedCnodes()) {
      quant::QuantType curr_quant_type;
      if (GetQuantTypeNew(cnode, &curr_quant_type) != RET_OK) {
        MS_LOG(ERROR) << "Get quant type failed, cnode name: " << cnode->fullname_with_scope();
        return RET_ERROR;
      }
      quant::InsertQuantNodeManager insert_node_manager;
      status = insert_node_manager.InsertCastNodeForFullQuant(func_graph, cnode, kNumberTypeFloat32, curr_quant_type);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "InsertForwardCastNode failed, cnode name: " << cnode->fullname_with_scope();
        ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
        return status;
      }
    }
  }

  if (this->param_->fullQuantParam.bias_correction) {
    MS_LOG(INFO) << "do bias correction";
    BiasCorrectionStrategy strategy(param_, calibrator_, quant_strategy_, fp32_ms_model_, init_param_.activation_q_min_,
                                    init_param_.activation_q_max_);
    status = strategy.DoBiasCorrection(func_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do bias correction failed.";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return RET_ERROR;
    }
  }
  return RET_OK;
}
734 } // namespace mindspore::lite::quant
735