• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/common/quant_utils.h"
18 #include <functional>
19 #include <map>
20 #include <cmath>
21 #include <cfloat>
22 
23 namespace mindspore {
24 namespace lite {
25 // `symmetric` == true -> q range is [-127 , 127];
26 //  abs_max = max(abs(r_min),abs(r_max)); r_min = -abs_max and r_max = abs_max.
27 //  `symmetric` == false -> q range is [-128 , 127]; r_min and r_max keep their original values.
28 // `narrow_range` is used to adjust q_min, and symmetric is always true.
CalQuantizationParams(schema::QuantParamT * quant_param,double real_min,double real_max,int num_bits,bool symmetric,bool narrow_range)29 int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
30                           bool symmetric, bool narrow_range) {
31   CHECK_NULL_RETURN(quant_param);
32   int quant_max = QuantMax(num_bits);
33   int quant_min = QuantMin(num_bits, false, narrow_range);
34   return CalQuantizationParams(quant_param, real_min, real_max, num_bits, quant_min, quant_max, symmetric,
35                                narrow_range);
36 }
37 
EncodeMinMax(float min_value,float max_value,int quant_min,int quant_max,bool symmetric,float * encode_min,float * encode_max)38 void EncodeMinMax(float min_value, float max_value, int quant_min, int quant_max, bool symmetric, float *encode_min,
39                   float *encode_max) {
40   // handle case where encode_min_ == encode_max_
41   float epsilon = 1e-10;
42   if (max_value - min_value < epsilon) {
43     MS_LOG(INFO) << min_value << " - " << max_value;
44   }
45   max_value = std::max(max_value, min_value + epsilon);
46   if (symmetric) {
47     auto abs_max = std::max(std::fabs(min_value), std::fabs(max_value));
48     *encode_min = -abs_max;
49     *encode_max = abs_max;
50   } else {
51     *encode_min = min_value;
52     *encode_max = max_value;
53   }
54   // Handling 0
55   // Inputs are strictly positive, set the real min to 0. e.g. input range = [1.0, 5.0] -> [0.0, 5.0]
56   if (*encode_min > 0.0f) {
57     MS_LOG(DEBUG) << "min " << *encode_min << " is bigger then 0, set to 0, this may course low precision";
58     *encode_min = 0.0f;
59   }
60   // Inputs are strictly negative, set the real max to 0. e.g. input range = [-5.0, -1.0] -> [-5.0, 0.0]
61   if (*encode_max < 0.0f) {
62     MS_LOG(DEBUG) << "real_max " << *encode_max << " is smaller than 0, set to 0, this may course low precision";
63     *encode_max = 0.0f;
64   }
65   auto q_range = quant_max - quant_min;
66   MS_ASSERT(quant_max - quant_min > 0);
67   // Inputs are both negative and positive, real_min and real_max are slightly shifted to make the floating point zero
68   // exactly representable. e.g. input range = [-5.1, 5.1] -> [-5.12, 5.08]
69   double step_size = static_cast<double>(*encode_max - *encode_min) / q_range;
70   auto close_0 = std::round(-(*encode_min) / step_size);
71   *encode_min = (0 - close_0) * step_size;
72   *encode_max = (q_range - close_0) * step_size;
73 }
74 
CalQuantizationParams(schema::QuantParamT * quant_param,double real_min,double real_max,int num_bits,int quant_min,int quant_max,bool symmetric,bool narrow_range)75 int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
76                           int quant_min, int quant_max, bool symmetric, bool narrow_range) {
77   CHECK_NULL_RETURN(quant_param);
78   float encode_min = real_min;
79   float encode_max = real_max;
80   EncodeMinMax(real_min, real_max, quant_min, quant_max, symmetric, &encode_min, &encode_max);
81   auto q_range = quant_max - quant_min;
82   double scale = (encode_max - encode_min) / q_range;
83   if (fabs(scale) <= 0.0f) {
84     MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
85     return RET_ERROR;
86   }
87   int zero_point = static_cast<int32_t>(std::round(quant_min - encode_min / scale));
88 
89   // The zero point should always be in the range of quantized value,
90   // [qmin, qmax].
91   MS_ASSERT(zero_point >= quant_min);
92   MS_ASSERT(zero_point <= quant_max);
93   quant_param->inited = true;
94   quant_param->min = encode_min;
95   quant_param->max = encode_max;
96   quant_param->scale = scale;
97   quant_param->zeroPoint = zero_point;
98   quant_param->varCorr = 1.0;
99   quant_param->meanCorr = 0;
100   quant_param->numBits = num_bits;
101   quant_param->narrowRange = narrow_range;
102 
103   return RET_OK;
104 }
105 
106 // Get the index of the bucket to which the current data belongs.
GetBucketIndex(const std::vector<int> & dims,int preferred_dim,int data_index)107 int GetBucketIndex(const std::vector<int> &dims, int preferred_dim, int data_index) {
108   int stride = 1;
109   int bucket_count = dims[preferred_dim];
110   for (size_t i = static_cast<size_t>(preferred_dim + 1); i < dims.size(); i++) {
111     stride *= dims[i];
112   }
113   if (stride == 0 || bucket_count == 0) {
114     MS_LOG(ERROR) << "stride or bucket_count is 0.";
115     return 0;
116   }
117   return (data_index / stride) % bucket_count;
118 }
119 
GetAllChannelMinMax(const float * raw_datas,size_t elem_count,const std::vector<int> & dims,int preferred_dim,std::map<int,MinMax> * per_channel_min_max)120 void GetAllChannelMinMax(const float *raw_datas, size_t elem_count, const std::vector<int> &dims, int preferred_dim,
121                          std::map<int, MinMax> *per_channel_min_max) {
122   MS_ASSERT(raw_datas != nullptr);
123   MS_ASSERT(per_channel_min_max != nullptr);
124   // the key is bucket_index
125   for (int i = 0; i < dims[preferred_dim]; ++i) {
126     per_channel_min_max->insert({i, {FLT_MAX, -FLT_MAX}});
127   }
128   // the first dim.
129   if (preferred_dim == 0) {
130     auto bucket_size = elem_count / dims[preferred_dim];
131     for (int i = 0; i < dims[preferred_dim]; ++i) {
132       auto mim_max = GetFloatMinMaxValue(raw_datas + i * bucket_size, bucket_size);
133       auto iter = per_channel_min_max->find(i);
134       MS_ASSERT(iter != per_channel_min_max->end());
135       iter->second.min = mim_max.first;
136       iter->second.max = mim_max.second;
137     }
138   } else {
139     for (size_t i = 0; i < elem_count; ++i) {
140       auto bucket_index = GetBucketIndex(dims, preferred_dim, i);
141       auto iter = per_channel_min_max->find(bucket_index);
142       MS_ASSERT(iter != per_channel_min_max->end());
143       iter->second.min = std::min(iter->second.min, raw_datas[i]);
144       iter->second.max = std::max(iter->second.max, raw_datas[i]);
145     }
146   }
147 }
148 
CalPerChannelGain(size_t bit_num,const std::vector<int> & dims,int preferred_dim)149 int CalPerChannelGain(size_t bit_num, const std::vector<int> &dims, int preferred_dim) {
150   auto elem_count = std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<>());
151   const int bits_per_byte = 8;
152   const int quant_param_size = 32;
153   int channels = dims.at(preferred_dim);
154   if (channels < 1) {
155     MS_LOG(ERROR) << "channels must not less 1";
156     return RET_ERROR;
157   }
158   size_t bucket_size = static_cast<size_t>(elem_count / channels);
159   bool do_quant = (quant_param_size * bits_per_byte) / (sizeof(float) * bits_per_byte - bit_num) < bucket_size;
160   if (do_quant) {
161     return RET_OK;
162   } else {
163     MS_LOG(INFO) << "too few elements in a filter, no need to quantize. " << bucket_size;
164     return RET_NO_CHANGE;
165   }
166 }
167 
CalWeightQuantBias(const float * raw_datas,size_t elem_count,const std::vector<float> & dequant_datas,std::vector<schema::QuantParamT> * quant_params,const std::vector<int> & dims,int preferred_dim)168 int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vector<float> &dequant_datas,
169                        std::vector<schema::QuantParamT> *quant_params, const std::vector<int> &dims,
170                        int preferred_dim) {
171   CHECK_NULL_RETURN(raw_datas);
172   CHECK_NULL_RETURN(quant_params);
173   std::map<int, double> total_raws;
174   std::map<int, double> total_dequants;
175   std::map<int, double> average_raws;
176   std::map<int, double> average_dequants;
177   std::map<int, double> var_raws;
178   std::map<int, double> var_dequants;
179   size_t bucket_size = quant_params->size();
180   int bucket_volume = static_cast<size_t>(elem_count / dims[preferred_dim]);
181   // Init Map
182   for (size_t i = 0; i < bucket_size; i++) {
183     total_raws[i] = 0;
184     total_dequants[i] = 0;
185     average_raws[i] = 0;
186     average_dequants[i] = 0;
187     var_raws[i] = 0;
188     var_dequants[i] = 0;
189   }
190   for (size_t data_index = 0; data_index < elem_count; data_index++) {
191     auto data = raw_datas[data_index];
192     auto dequant_data = dequant_datas[data_index];
193     auto bucket_index = GetBucketIndex(dims, preferred_dim, data_index);
194     total_raws[bucket_index] += data;
195     total_dequants[bucket_index] += dequant_data;
196   }
197   for (size_t bucket_index = 0; bucket_index < bucket_size; bucket_index++) {
198     average_raws[bucket_index] = total_raws[bucket_index] / bucket_volume;
199     average_dequants[bucket_index] = total_dequants[bucket_index] / bucket_volume;
200   }
201 
202   constexpr int pow_exponent = 2;
203   for (size_t data_index = 0; data_index < elem_count; data_index++) {
204     auto bucket_index = GetBucketIndex(dims, preferred_dim, data_index);
205     var_raws[bucket_index] += std::pow(raw_datas[data_index] - average_raws[bucket_index], pow_exponent);
206     var_dequants[bucket_index] += std::pow(dequant_datas[data_index] - average_dequants[bucket_index], pow_exponent);
207   }
208   for (size_t bucket_index = 0; bucket_index < bucket_size; bucket_index++) {
209     var_raws[bucket_index] = std::sqrt(var_raws[bucket_index] / bucket_volume);
210     var_dequants[bucket_index] = std::sqrt(var_dequants[bucket_index] / bucket_volume);
211   }
212   for (size_t bucket_index = 0; bucket_index < bucket_size; bucket_index++) {
213     quant_params->at(bucket_index).varCorr = 1;
214     if (fabs(var_raws[bucket_index]) > DBL_EPSILON && fabs(var_dequants[bucket_index]) > DBL_EPSILON) {
215       auto temp_var_corr = var_raws[bucket_index] / var_dequants[bucket_index];
216       const int min_var_corr = 0;
217       const int max_var_corr = 10;
218       if (temp_var_corr > min_var_corr && temp_var_corr < max_var_corr) {
219         quant_params->at(bucket_index).varCorr = temp_var_corr;
220       } else {
221         MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
222       }
223     }
224     quant_params->at(bucket_index).meanCorr =
225       average_raws[bucket_index] - average_dequants[bucket_index] * quant_params->at(bucket_index).varCorr;
226     MS_LOG(INFO) << "dims:" << dims << " bucket_index:" << bucket_index
227                  << " average_raws[bucket_index]:" << average_raws[bucket_index]
228                  << " average_dequants[bucket_index]:" << average_dequants[bucket_index]
229                  << " var_raws[bucket_index]:" << var_dequants[bucket_index]
230                  << " var_dequants[bucket_index]:" << var_dequants[bucket_index]
231                  << " varCorr:" << quant_params->at(bucket_index).varCorr
232                  << " meanCorr:" << quant_params->at(bucket_index).meanCorr;
233   }
234   return RET_OK;
235 }
236 
CalWeightQuantBiasPerLayer(const float * raw_datas,size_t elem_count,const std::vector<float> & dequant_datas,std::vector<schema::QuantParamT> * quant_params)237 int CalWeightQuantBiasPerLayer(const float *raw_datas, size_t elem_count, const std::vector<float> &dequant_datas,
238                                std::vector<schema::QuantParamT> *quant_params) {
239   CHECK_NULL_RETURN(raw_datas);
240   CHECK_NULL_RETURN(quant_params);
241   MS_CHECK_GT(elem_count, 0, RET_ERROR);
242   double total_raws = 0;
243   double total_dequants = 0;
244   double average_raws = 0;
245   double average_dequants = 0;
246   double var_raws = 0;
247   double var_dequants = 0;
248 
249   for (size_t data_index = 0; data_index < elem_count; data_index++) {
250     auto data = raw_datas[data_index];
251     auto dequant_data = dequant_datas[data_index];
252     total_raws += data;
253     total_dequants += dequant_data;
254   }
255 
256   average_raws = total_raws / elem_count;
257   average_dequants = total_dequants / elem_count;
258 
259   constexpr int pow_exponent = 2;
260   for (size_t data_index = 0; data_index < elem_count; data_index++) {
261     var_raws += std::pow(raw_datas[data_index] - average_raws, pow_exponent);
262     var_dequants += std::pow(dequant_datas[data_index] - average_dequants, pow_exponent);
263   }
264 
265   var_raws = std::sqrt(var_raws / elem_count);
266   var_dequants = std::sqrt(var_dequants / elem_count);
267 
268   quant_params->at(0).varCorr = 1;
269   if (fabs(var_raws) > DBL_EPSILON && fabs(var_dequants) > DBL_EPSILON) {
270     auto temp_var_corr = var_raws / var_dequants;
271     const int min_var_corr = 0;
272     const int max_var_corr = 10;
273     if (temp_var_corr > min_var_corr && temp_var_corr < max_var_corr) {
274       quant_params->at(0).varCorr = temp_var_corr;
275     } else {
276       MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
277     }
278   }
279   quant_params->at(0).meanCorr = average_raws - average_dequants * quant_params->at(0).varCorr;
280   MS_LOG(INFO) << " average_raws:" << average_raws << " average_dequants:" << average_dequants
281                << " var_raws:" << var_dequants << " var_dequants:" << var_dequants
282                << " varCorr:" << quant_params->at(0).varCorr << " meanCorr:" << quant_params->at(0).meanCorr;
283   return RET_OK;
284 }
285 }  // namespace lite
286 }  // namespace mindspore
287