1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZATION_H_ 18 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZATION_H_ 19 #include <cstdint> 20 #include <vector> 21 #include <cmath> 22 #include <string> 23 #include "tools/converter/quantizer/quant_params.h" 24 #include "src/common/log_adapter.h" 25 #include "src/common/quant_utils.h" 26 #include "ir/tensor.h" 27 28 namespace mindspore::lite::quant { 29 class MixedBitWeightQuantization { 30 public: 31 explicit MixedBitWeightQuantization(float target_relative_err, float target_search_tolerance = 0.01, 32 int max_search_iters = 100) target_relative_err_(target_relative_err)33 : target_relative_err_(target_relative_err), 34 target_search_tolerance_(target_search_tolerance), 35 max_search_iters_(max_search_iters) {} 36 ~MixedBitWeightQuantization() = default; 37 38 int QuantFilter(const PrimitivePtr &primitive, const AnfNodePtr ¶meter_node, const tensor::TensorPtr &weight, 39 QuantType quant_type, bool use_auto_tune_alg = false); 40 41 private: 42 int DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim, 43 std::vector<schema::QuantParamT> *quant_params, std::vector<int16_t> *quant_datas, 44 const std::string &node_name, bool use_auto_tune_alg = false); 45 float MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim, float scale); 46 47 static MinMax GetMinMax(const float *arr, int arrc); 48 static LayerParam CalculateLayerParams(const float *weights, int element_num); 49 50 int QuantizeByScale(const float *weights, int weightsc, float scale, schema::QuantParamT *quant_params, 51 std::vector<int16_t> *quant_datas); 52 53 BinarySearchResult BinarySearchForQuantizationScale(float *weights, int *shape, int dims, int preferred_dim, 54 int max_iters, float target_err, float rel_tol); 55 56 float GetDx(const float *weights, const int *shape, int dims, const std::string &node_name); 57 58 void CalculateBiasCorrection(const float *weights, int element_num, float scale, float *origin_dequant_datas); 59 60 float CalculateMeanError(std::vector<float> norms2, std::vector<float> dnorms2); 61 62 private: 63 float var_corr_{1}; 64 float mean_corr_{0}; 65 float target_relative_err_; 66 float target_search_tolerance_; 67 int max_search_iters_; 68 }; 69 } // namespace mindspore::lite::quant 70 #endif // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZATION_H_ 71