/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZATION_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZATION_H_
#include <cstdint>
#include <vector>
#include <cmath>
#include <string>
#include "tools/converter/quantizer/quant_params.h"
#include "src/common/log_adapter.h"
#include "src/common/quant_utils.h"
#include "ir/tensor.h"

namespace mindspore::lite::quant {
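// Mixed-bit weight quantizer: instead of a fixed bit width per layer, it
// searches for a quantization scale whose reconstruction error stays within a
// caller-supplied relative-error target.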
class MixedBitWeightQuantization {
 public:
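  // target_relative_err: relative quantization error the scale search aims for.
  // target_search_tolerance: relative tolerance at which the search is considered converged.
  // max_search_iters: upper bound on the number of search iterations.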
  explicit MixedBitWeightQuantization(float target_relative_err, float target_search_tolerance = 0.01,
                                      int max_search_iters = 100)
      : target_relative_err_(target_relative_err),
        target_search_tolerance_(target_search_tolerance),
        max_search_iters_(max_search_iters) {}
  ~MixedBitWeightQuantization() = default;

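  // Public entry point: quantizes the weight tensor attached to `parameter_node`
  // and records the resulting quantization parameters on `primitive`. Expected
  // to return RET_OK on success, per the MindSpore Lite status-code convention.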
  int QuantFilter(const PrimitivePtr &primitive, const AnfNodePtr &parameter_node, const tensor::TensorPtr &weight,
                  QuantType quant_type, bool use_auto_tune_alg = false);

 private:
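  // Runs the scale search and quantization for one weight tensor, filling
  // `quant_params` and the 16-bit quantized values in `quant_datas`.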
  int DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
                     std::vector<schema::QuantParamT> *quant_params, std::vector<int16_t> *quant_datas,
                     const std::string &node_name, bool use_auto_tune_alg = false);
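  // Quantizes with the candidate `scale` and returns the resulting error
  // metric (presumably relative, given target_relative_err_).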
  float MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim, float scale);

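  // Stateless helpers: tensor min/max statistics and per-layer parameter estimation.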
  static MinMax GetMinMax(const float *arr, int arrc);
  static LayerParam CalculateLayerParams(const float *weights, int element_num);

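  // Quantizes `weights` with a fixed `scale` and records that scale in `quant_params`.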
  int QuantizeByScale(const float *weights, int weightsc, float scale, schema::QuantParamT *quant_params,
                      std::vector<int16_t> *quant_datas);

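  // Core of the mixed-bit scheme: binary-searches for a scale whose measured
  // error meets `target_err`, stopping after `max_iters` iterations or once the
  // search interval shrinks below `rel_tol`.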
  BinarySearchResult BinarySearchForQuantizationScale(float *weights, int *shape, int dims, int preferred_dim,
                                                      int max_iters, float target_err, float rel_tol);

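  // Derives an initial quantization step ("dx") for the tensor, used to seed
  // the scale search (inferred from the name and call signature).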
  float GetDx(const float *weights, const int *shape, int dims, const std::string &node_name);

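  // Compares the original weights against their dequantized values and updates
  // the bias-correction factors var_corr_ / mean_corr_ accordingly.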
  void CalculateBiasCorrection(const float *weights, int element_num, float scale, float *origin_dequant_datas);

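  // Aggregates squared norms of the original (norms2) and dequantized (dnorms2)
  // weights into a single mean error value.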
  float CalculateMeanError(std::vector<float> norms2, std::vector<float> dnorms2);

 private:
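  // Bias-correction factors (variance and mean) applied after quantization.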
  float var_corr_{1};
  float mean_corr_{0};
  float target_relative_err_;
  float target_search_tolerance_;
  int max_search_iters_;
};
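// A minimal usage sketch (illustrative, not part of this header): `primitive`,
// `param_node`, and `weight` are assumed to come from the caller's graph walk,
// and QuantType_QUANT_WEIGHT from the generated schema.
//
//   MixedBitWeightQuantization quantizer(0.02f);
//   auto status = quantizer.QuantFilter(primitive, param_node, weight, schema::QuantType_QUANT_WEIGHT);
//   if (status != RET_OK) {
//     MS_LOG(ERROR) << "Mixed bit weight quantization failed.";
//   }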
}  // namespace mindspore::lite::quant
#endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_MIXED_BIT_WEIGHT_QUANTIZATION_H_