/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIXED_BIT_WEIGHT_QUANTIZATION_H
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIXED_BIT_WEIGHT_QUANTIZATION_H

#include <vector>
#include <functional>
#include <map>
#include <memory>
#include "ir/tensor.h"
#include "schema/inner/model_generated.h"
#include "src/common/log_adapter.h"
#include "src/common/quant_utils.h"
#include "tools/converter/quantizer/quant_params.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "mindspore/core/ir/quantization_param.h"

namespace mindspore::lite::quant {
class FixedBitWeightQuantization {
 public:
  FixedBitWeightQuantization() = default;

  ~FixedBitWeightQuantization() = default;

  // Quantizes a weight tensor to fixed-bit integers, writes the quantized data
  // back to the tensor, and records the quant type on the primitive.
  int QuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight, const PrimitivePtr &primitive,
                  quant::QuantType quant_type, int quant_max, int quant_min, size_t bit_num,
                  WeightQuantType weight_quant_type, TypeId quant_data_type, int preferred_dim, bool symmetric = false,
                  bool narrow_range = false, bool bias_correction = true);

  // Quantizes a bias parameter using the activation and weight quant params.
  int QuantBias(const ParameterPtr &weight, const ParameterPtr &bias,
                const std::vector<schema::QuantParamT> &active_quant_params);

 private:
  int ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
                                   const float *raw_datas, std::vector<schema::QuantParamT> *weight_quant_params,
                                   const tensor::TensorPtr &weight, std::vector<schema::QuantParamT> *bias_quant_params,
                                   std::vector<int32_t> *quant_datas);

  // Quantizes the float32 weight to integer type T, updates the tensor data,
  // and tags the primitive with the quant type.
  template <typename T>
  int FixedBitQuantFilter(const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
                          const PrimitivePtr &primitive, quant::QuantType quant_type, int quant_max, int quant_min,
                          size_t bit_num, WeightQuantType weight_quant_type, TypeId quant_data_type, int preferred_dim,
                          bool symmetric = false, bool narrow_range = false, bool bias_correction = true) {
    size_t elem_count = weight->DataSize();
    auto *raw_data = static_cast<float *>(weight->data_c());
    if (raw_data == nullptr) {
      MS_LOG(ERROR) << "raw_data is nullptr";
      return RET_ERROR;
    }
    std::vector<T> quant_data(elem_count);
    auto status = FixedBitStatisticsFilter<T>(weight, quant_type, quant_max, quant_min, bit_num, weight_quant_type,
                                              preferred_dim, &quant_data, symmetric, narrow_range, bias_correction);
    if (status == RET_NO_CHANGE) {
      return status;
    } else if (status != RET_OK) {
      MS_LOG(ERROR) << "FixedBitStatisticsFilter failed: " << status;
      return status;
    }
    status = UpdateTensorDataAndSize(parameter_node, weight, quant_data.data(), quant_data.size() * sizeof(T),
                                     quant_data_type);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
      return RET_ERROR;
    }
    auto quant_type_value = MakeValue(static_cast<int>(quant_type));
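    // A null ValuePtr from MakeValue indicates failure, so verify it before
    // attaching the quant-type attribute to the primitive.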
    MS_CHECK_TRUE_MSG(quant_type_value != nullptr, RET_ERROR, "quant_type_value is nullptr.");
    primitive->AddAttr(quant::kQuantType, quant_type_value);
    return RET_OK;
  }

  // Computes quant params and quantized values per-channel when the shape
  // allows it, otherwise per-layer, and stores the quant params on the weight.
  template <typename T>
  int FixedBitStatisticsFilter(const tensor::TensorPtr &weight, quant::QuantType quant_type, int quant_max,
                               int quant_min, size_t bit_num, WeightQuantType weight_quant_type, int preferred_dim,
                               std::vector<T> *quant_data, bool symmetric = false, bool narrow_range = false,
                               bool bias_correction = true) {
    MS_ASSERT(weight != nullptr);
    auto dims = weight->shape();
    if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
      if (dims.size() <= 1) {
        MS_LOG(WARNING) << "dims size is " << dims.size()
                        << ", which cannot be quantized per-channel; falling back to per-layer";
        weight_quant_type = FIXED_BIT_PER_LAYER;
      }
    }
    if (weight->data_type_c() != kNumberTypeFloat32) {
      MS_LOG(ERROR) << "data type is not Float32.";
      return RET_ERROR;
    }

    std::vector<schema::QuantParamT> quant_params;
    int ret = RET_OK;
    // Bias-correction gain is only computed for plain weight quantization.
    bool cal_gain = (quant_type == QUANT_WEIGHT) && bias_correction;
    if (weight_quant_type == FIXED_BIT_PER_CHANNEL) {
      ret = DoPerChannelQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
                                 quant_min, bit_num, quant_data, ConvertShapeVectorToInt32(dims), preferred_dim,
                                 cal_gain, symmetric, narrow_range);
      if (ret == RET_NO_CHANGE) {
        return ret;
      } else if (ret != RET_OK) {
        MS_LOG(ERROR) << "Do per channel quant failed.";
        return ret;
      }
    } else if (weight_quant_type == FIXED_BIT_PER_LAYER) {
      ret = DoPerLayerQuant<T>(static_cast<float *>(weight->data_c()), weight->DataSize(), &quant_params, quant_max,
                               quant_min, bit_num, quant_data, symmetric, narrow_range, cal_gain);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Do per layer quant failed.";
        return ret;
      }
    } else {
      MS_LOG(ERROR) << "Unsupported weight quant type: " << weight_quant_type;
      return RET_ERROR;
    }
    auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(quant_params);
    CHECK_NULL_RETURN(quantization_ptr);
    weight->set_quant_param(std::vector<std::shared_ptr<mindspore::QuantizationParam>>{quantization_ptr});
    return ret;
  }
};
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FIXED_BIT_WEIGHT_QUANTIZATION_H