• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
19 
20 namespace mindspore::lite::quant {
21 // the error is currently measured per channel.
22 // it could be measured per layer but it would be less good.
23 // the `preferred` dim should point to the output channels dimension.
MeasureQuantizationError(float * weights,const int * shape,int dims,int preferred_dim,float scale)24 float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
25                                                         float scale) {
26   MS_ASSERT(weights != nullptr);
27   MS_ASSERT(shape != nullptr);
28   int numel = 1;
29   for (int i = 0; i < dims; i++) {
30     numel *= shape[i];
31   }
32   int bucket_count = shape[preferred_dim];
33   std::vector<float> norms2(bucket_count);
34   std::vector<float> dnorms2(bucket_count);
35   for (int i = 0; i < bucket_count; i++) {
36     norms2[i] = 0.0;
37     dnorms2[i] = 0.0;
38   }
39   double average_dequant = 0;
40   double average_raw = 0;
41   std::vector<float> dequant_datas(numel);
42   int bucket_volume = 1;
43   for (int i = preferred_dim; i < dims; i++) {
44     bucket_volume *= shape[i];
45   }
46   for (int i = 0; i < numel; i++) {
47     float dequant = scale * (floorf(weights[i] / scale + 0.5));
48     dequant_datas[i] = dequant;
49     average_raw += weights[i];
50     average_dequant += dequant;
51   }
52   // mean
53   average_dequant = average_dequant / numel;
54   average_raw = average_raw / numel;
55   // std
56   double variance_dequant = 0;
57   double variance_raw = 0;
58   for (int i = 0; i < numel; i++) {
59     variance_dequant += std::pow(dequant_datas[i] - average_dequant, 2);
60     variance_raw += std::pow(weights[i] - average_raw, 2);
61   }
62   variance_dequant = std::sqrt(variance_dequant / numel);
63   variance_raw = std::sqrt(variance_raw / numel);
64   var_corr_ = variance_raw / variance_dequant;
65   mean_corr_ = average_raw - average_dequant * var_corr_;
66 
67   for (int i = 0; i < numel; i++) {
68     int bucket = (i / bucket_volume) % bucket_count;
69     norms2[bucket] += weights[i] * weights[i];
70     float dequant = var_corr_ * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr_;
71     float d = weights[i] - dequant;
72     dnorms2[bucket] += d * d;
73   }
74 
75   int c = 0;
76   float t = 1e-10;
77   for (int i = 0; i < bucket_count; i++) {
78     if (norms2[i] < 1.0e-10) continue;
79     c += 1;
80     t += sqrtf(dnorms2[i] / norms2[i]);
81   }
82   return t / (c + 1e-7);
83 }
84 
GetMinMax(const float * arr,int arrc)85 MinMax MixedBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
86   MS_ASSERT(arr != nullptr);
87   MinMax min_max = {INFINITY, -INFINITY};
88   for (int i = 0; i < arrc; i++)
89     if (arr[i] > min_max.max)
90       min_max.max = arr[i];
91     else if (arr[i] < min_max.min)
92       min_max.min = arr[i];
93   return min_max;
94 }
95 
BinarySearchForQuantizationScale(float * weights,int * shape,int dims,int preferred_dim,int max_iters,float target_err,float rel_tol)96 BinarySearchResult MixedBitWeightQuantizer::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
97                                                                              int preferred_dim, int max_iters,
98                                                                              float target_err, float rel_tol) {
99   MS_ASSERT(weights != nullptr);
100   MS_ASSERT(shape != nullptr);
101   int element_num = 1;
102   for (int i = 0; i < dims; i++) {
103     element_num *= shape[i];
104   }
105   MinMax mm = GetMinMax(weights, element_num);
106   if (mm.max < mm.min + 1.0e-5) {
107     return {0, static_cast<float>(std::fabs(mm.max) + 1.0e-5)};
108   }
109   // start a binary search
110   float curr_scale = (mm.max - mm.min) * target_err;
111   float right_hs_dx = curr_scale * 2.0;
112   while (MeasureQuantizationError(weights, shape, dims, preferred_dim, right_hs_dx) < target_err) {
113     right_hs_dx *= 2.0;
114   }
115   float left_hs_dx = curr_scale / 2.0;
116   while (MeasureQuantizationError(weights, shape, dims, preferred_dim, left_hs_dx) > target_err) {
117     left_hs_dx /= 2.0;
118   }
119   int iter_count = 0;
120   BinarySearchResult res = {0, curr_scale};
121   while (true) {
122     float curr_err = MeasureQuantizationError(weights, shape, dims, preferred_dim, res.scale);
123     if (std::fabs(curr_err - target_err) / target_err < rel_tol) {
124       return res;
125     }
126     if (iter_count > max_iters) {
127       res.status = 1;
128       return res;
129     }
130     if (curr_err > target_err)
131       right_hs_dx = res.scale;
132     else
133       left_hs_dx = res.scale;
134     res.scale = (left_hs_dx + right_hs_dx) / 2.0;
135     iter_count += 1;
136   }
137 }
138 
DoQuantization(float * weights,std::vector<int64_t> shape,int preferred_dim,std::vector<schema::QuantParamT> * quant_params,std::vector<int16_t> * quant_datas)139 int MixedBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
140                                             std::vector<schema::QuantParamT> *quant_params,
141                                             std::vector<int16_t> *quant_datas) {
142   MS_ASSERT(weights != nullptr);
143   int weight_count = 1;
144   int dims = shape.size();
145   int input_shape[4] = {0, 0, 0, 0};
146   for (int i = 0; i < dims; i++) {
147     weight_count *= shape[i];
148     input_shape[i] = shape[i];
149   }
150 
151   BinarySearchResult br = BinarySearchForQuantizationScale(weights, input_shape, dims, preferred_dim, max_search_iters_,
152                                                            target_relative_err_, target_search_tolerance_);
153   if (br.status != RET_OK) {
154     MS_LOG(ERROR) << "reached_max_iters";
155     return RET_ERROR;
156   }
157   schema::QuantParamT quant_param;
158   int qr = QuantizeByScale(weights, weight_count, br.scale, &quant_param, quant_datas);
159   if (qr != RET_OK) {
160     MS_LOG(ERROR) << "quant failed.";
161     return RET_ERROR;
162   }
163 
164   // It is used to calculate the Shannon entropy.
165   quant_params->push_back(quant_param);
166   return RET_OK;
167 }
168 
QuantizeByScale(const float * weights,int weightsc,float scale,schema::QuantParamT * quant_params,std::vector<int16_t> * quant_datas)169 int MixedBitWeightQuantizer::QuantizeByScale(const float *weights, int weightsc, float scale,
170                                              schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
171   MS_ASSERT(weights != nullptr);
172   for (int i = 0; i < weightsc; i++) {
173     auto q = static_cast<int>(floorf(weights[i] / scale + 0.5));
174     quant_datas->at(i) = q;
175   }
176   quant_params->meanCorr = mean_corr_;
177   quant_params->varCorr = var_corr_;
178   quant_params->scale = scale;
179   quant_params->zeroPoint = 0;
180   quant_params->numBits = 0;
181   quant_params->inited = true;
182   return RET_OK;
183 }
184 }  // namespace mindspore::lite::quant
185