1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/common/quant_utils.h"
18 #include <functional>
19 #include <map>
20 #include <cmath>
21 #include <cfloat>
22
23 namespace mindspore {
24 namespace lite {
25 // `symmetric` == true -> q range is [-127 , 127];
26 // abs_max = max(abs(r_min),abs(r_max)); r_min = -abs_max and r_max = abs_max.
27 // `symmetric` == false q range is [-128 , 127]. r_min or r_max keep the original value.
28 // `narrow_range` is used to adjust q_min, and symmetric is always true.
CalQuantizationParams(schema::QuantParamT * quant_param,double real_min,double real_max,int num_bits,bool symmetric,bool narrow_range)29 int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
30 bool symmetric, bool narrow_range) {
31 CHECK_NULL_RETURN(quant_param);
32 int quant_max = QuantMax(num_bits);
33 int quant_min = QuantMin(num_bits, false, narrow_range);
34 return CalQuantizationParams(quant_param, real_min, real_max, num_bits, quant_min, quant_max, symmetric,
35 narrow_range);
36 }
37
EncodeMinMax(float min_value,float max_value,int quant_min,int quant_max,bool symmetric,float * encode_min,float * encode_max)38 void EncodeMinMax(float min_value, float max_value, int quant_min, int quant_max, bool symmetric, float *encode_min,
39 float *encode_max) {
40 // handle case where encode_min_ == encode_max_
41 float epsilon = 1e-10;
42 if (max_value - min_value < epsilon) {
43 MS_LOG(INFO) << min_value << " - " << max_value;
44 }
45 max_value = std::max(max_value, min_value + epsilon);
46 if (symmetric) {
47 auto abs_max = std::max(std::fabs(min_value), std::fabs(max_value));
48 *encode_min = -abs_max;
49 *encode_max = abs_max;
50 } else {
51 *encode_min = min_value;
52 *encode_max = max_value;
53 }
54 // Handling 0
55 // Inputs are strictly positive, set the real min to 0. e.g. input range = [1.0, 5.0] -> [0.0, 5.0]
56 if (*encode_min > 0.0f) {
57 MS_LOG(DEBUG) << "min " << *encode_min << " is bigger then 0, set to 0, this may course low precision";
58 *encode_min = 0.0f;
59 }
60 // Inputs are strictly negative, set the real max to 0. e.g. input range = [-5.0, -1.0] -> [-5.0, 0.0]
61 if (*encode_max < 0.0f) {
62 MS_LOG(DEBUG) << "real_max " << *encode_max << " is smaller than 0, set to 0, this may course low precision";
63 *encode_max = 0.0f;
64 }
65 auto q_range = quant_max - quant_min;
66 MS_ASSERT(quant_max - quant_min > 0);
67 // Inputs are both negative and positive, real_min and real_max are slightly shifted to make the floating point zero
68 // exactly representable. e.g. input range = [-5.1, 5.1] -> [-5.12, 5.08]
69 double step_size = static_cast<double>(*encode_max - *encode_min) / q_range;
70 auto close_0 = std::round(-(*encode_min) / step_size);
71 *encode_min = (0 - close_0) * step_size;
72 *encode_max = (q_range - close_0) * step_size;
73 }
74
CalQuantizationParams(schema::QuantParamT * quant_param,double real_min,double real_max,int num_bits,int quant_min,int quant_max,bool symmetric,bool narrow_range)75 int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
76 int quant_min, int quant_max, bool symmetric, bool narrow_range) {
77 CHECK_NULL_RETURN(quant_param);
78 float encode_min = real_min;
79 float encode_max = real_max;
80 EncodeMinMax(real_min, real_max, quant_min, quant_max, symmetric, &encode_min, &encode_max);
81 auto q_range = quant_max - quant_min;
82 double scale = (encode_max - encode_min) / q_range;
83 if (fabs(scale) <= 0.0f) {
84 MS_LOG(ERROR) << "divisor 'scale' cannot be 0";
85 return RET_ERROR;
86 }
87 int zero_point = static_cast<int32_t>(std::round(quant_min - encode_min / scale));
88
89 // The zero point should always be in the range of quantized value,
90 // [qmin, qmax].
91 MS_ASSERT(zero_point >= quant_min);
92 MS_ASSERT(zero_point <= quant_max);
93 quant_param->inited = true;
94 quant_param->min = encode_min;
95 quant_param->max = encode_max;
96 quant_param->scale = scale;
97 quant_param->zeroPoint = zero_point;
98 quant_param->varCorr = 1.0;
99 quant_param->meanCorr = 0;
100 quant_param->numBits = num_bits;
101 quant_param->narrowRange = narrow_range;
102
103 return RET_OK;
104 }
105
106 // Get the index of the bucket to which the current data belongs.
GetBucketIndex(const std::vector<int> & dims,int preferred_dim,int data_index)107 int GetBucketIndex(const std::vector<int> &dims, int preferred_dim, int data_index) {
108 int stride = 1;
109 int bucket_count = dims[preferred_dim];
110 for (size_t i = static_cast<size_t>(preferred_dim + 1); i < dims.size(); i++) {
111 stride *= dims[i];
112 }
113 if (stride == 0 || bucket_count == 0) {
114 MS_LOG(ERROR) << "stride or bucket_count is 0.";
115 return 0;
116 }
117 return (data_index / stride) % bucket_count;
118 }
119
GetAllChannelMinMax(const float * raw_datas,size_t elem_count,const std::vector<int> & dims,int preferred_dim,std::map<int,MinMax> * per_channel_min_max)120 void GetAllChannelMinMax(const float *raw_datas, size_t elem_count, const std::vector<int> &dims, int preferred_dim,
121 std::map<int, MinMax> *per_channel_min_max) {
122 MS_ASSERT(raw_datas != nullptr);
123 MS_ASSERT(per_channel_min_max != nullptr);
124 // the key is bucket_index
125 for (int i = 0; i < dims[preferred_dim]; ++i) {
126 per_channel_min_max->insert({i, {FLT_MAX, -FLT_MAX}});
127 }
128 // the first dim.
129 if (preferred_dim == 0) {
130 auto bucket_size = elem_count / dims[preferred_dim];
131 for (int i = 0; i < dims[preferred_dim]; ++i) {
132 auto mim_max = GetFloatMinMaxValue(raw_datas + i * bucket_size, bucket_size);
133 auto iter = per_channel_min_max->find(i);
134 MS_ASSERT(iter != per_channel_min_max->end());
135 iter->second.min = mim_max.first;
136 iter->second.max = mim_max.second;
137 }
138 } else {
139 for (size_t i = 0; i < elem_count; ++i) {
140 auto bucket_index = GetBucketIndex(dims, preferred_dim, i);
141 auto iter = per_channel_min_max->find(bucket_index);
142 MS_ASSERT(iter != per_channel_min_max->end());
143 iter->second.min = std::min(iter->second.min, raw_datas[i]);
144 iter->second.max = std::max(iter->second.max, raw_datas[i]);
145 }
146 }
147 }
148
CalPerChannelGain(size_t bit_num,const std::vector<int> & dims,int preferred_dim)149 int CalPerChannelGain(size_t bit_num, const std::vector<int> &dims, int preferred_dim) {
150 auto elem_count = std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<>());
151 const int bits_per_byte = 8;
152 const int quant_param_size = 32;
153 int channels = dims.at(preferred_dim);
154 if (channels < 1) {
155 MS_LOG(ERROR) << "channels must not less 1";
156 return RET_ERROR;
157 }
158 size_t bucket_size = static_cast<size_t>(elem_count / channels);
159 bool do_quant = (quant_param_size * bits_per_byte) / (sizeof(float) * bits_per_byte - bit_num) < bucket_size;
160 if (do_quant) {
161 return RET_OK;
162 } else {
163 MS_LOG(INFO) << "too few elements in a filter, no need to quantize. " << bucket_size;
164 return RET_NO_CHANGE;
165 }
166 }
167
// Computes per-channel bias-correction terms (varCorr / meanCorr) from the raw weights and their dequantized
// counterparts and stores them in `quant_params`.
//  - raw_datas: `elem_count` original float values, must not be null.
//  - dequant_datas: dequantized values, must hold at least `elem_count` entries.
//  - quant_params: one entry per bucket, updated in place, must not be null.
//  - dims / preferred_dim: tensor shape and channel axis used to map an element to its bucket.
// For every bucket: varCorr = std(raw) / std(dequant) when both deviations are non-zero and the ratio lies in
// (0, 10), otherwise 1; meanCorr = mean(raw) - mean(dequant) * varCorr.
// Returns RET_OK on success, RET_ERROR on invalid arguments.
int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vector<float> &dequant_datas,
                       std::vector<schema::QuantParamT> *quant_params, const std::vector<int> &dims,
                       int preferred_dim) {
  CHECK_NULL_RETURN(raw_datas);
  CHECK_NULL_RETURN(quant_params);
  if (elem_count == 0) {
    MS_LOG(ERROR) << "elem_count must be greater than 0.";
    return RET_ERROR;
  }
  if (dequant_datas.size() < elem_count) {
    MS_LOG(ERROR) << "dequant_datas size " << dequant_datas.size() << " is less than elem_count " << elem_count;
    return RET_ERROR;
  }
  if (preferred_dim < 0 || static_cast<size_t>(preferred_dim) >= dims.size()) {
    MS_LOG(ERROR) << "preferred_dim " << preferred_dim << " is out of range, dims size is " << dims.size();
    return RET_ERROR;
  }
  const int channels = dims[preferred_dim];
  if (channels <= 0) {
    MS_LOG(ERROR) << "channel count must be positive, but got " << channels;
    return RET_ERROR;
  }
  // Number of elements per bucket; it is the divisor of every mean / variance below.
  const size_t bucket_volume = elem_count / static_cast<size_t>(channels);
  if (bucket_volume == 0) {
    MS_LOG(ERROR) << "elem_count " << elem_count << " is less than channel count " << channels;
    return RET_ERROR;
  }
  std::map<int, double> total_raws;
  std::map<int, double> total_dequants;
  std::map<int, double> average_raws;
  std::map<int, double> average_dequants;
  std::map<int, double> var_raws;
  std::map<int, double> var_dequants;
  const size_t bucket_size = quant_params->size();
  // Init Map
  for (size_t i = 0; i < bucket_size; i++) {
    total_raws[i] = 0;
    total_dequants[i] = 0;
    average_raws[i] = 0;
    average_dequants[i] = 0;
    var_raws[i] = 0;
    var_dequants[i] = 0;
  }
  // First pass: per-bucket sums.
  for (size_t data_index = 0; data_index < elem_count; data_index++) {
    auto bucket_index = GetBucketIndex(dims, preferred_dim, static_cast<int>(data_index));
    total_raws[bucket_index] += raw_datas[data_index];
    total_dequants[bucket_index] += dequant_datas[data_index];
  }
  for (size_t bucket_index = 0; bucket_index < bucket_size; bucket_index++) {
    average_raws[bucket_index] = total_raws[bucket_index] / bucket_volume;
    average_dequants[bucket_index] = total_dequants[bucket_index] / bucket_volume;
  }

  // Second pass: per-bucket sums of squared deviations from the mean.
  for (size_t data_index = 0; data_index < elem_count; data_index++) {
    auto bucket_index = GetBucketIndex(dims, preferred_dim, static_cast<int>(data_index));
    const double raw_diff = raw_datas[data_index] - average_raws[bucket_index];
    const double dequant_diff = dequant_datas[data_index] - average_dequants[bucket_index];
    var_raws[bucket_index] += raw_diff * raw_diff;
    var_dequants[bucket_index] += dequant_diff * dequant_diff;
  }
  // Turn the sums into standard deviations.
  for (size_t bucket_index = 0; bucket_index < bucket_size; bucket_index++) {
    var_raws[bucket_index] = std::sqrt(var_raws[bucket_index] / bucket_volume);
    var_dequants[bucket_index] = std::sqrt(var_dequants[bucket_index] / bucket_volume);
  }
  for (size_t bucket_index = 0; bucket_index < bucket_size; bucket_index++) {
    auto &quant_param = quant_params->at(bucket_index);
    quant_param.varCorr = 1;
    if (fabs(var_raws[bucket_index]) > DBL_EPSILON && fabs(var_dequants[bucket_index]) > DBL_EPSILON) {
      auto temp_var_corr = var_raws[bucket_index] / var_dequants[bucket_index];
      const int min_var_corr = 0;
      const int max_var_corr = 10;
      // Only accept ratios inside (0, 10); anything else keeps the neutral value 1.
      if (temp_var_corr > min_var_corr && temp_var_corr < max_var_corr) {
        quant_param.varCorr = temp_var_corr;
      } else {
        MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
      }
    }
    quant_param.meanCorr = average_raws[bucket_index] - average_dequants[bucket_index] * quant_param.varCorr;
    MS_LOG(INFO) << "dims:" << dims << " bucket_index:" << bucket_index
                 << " average_raws[bucket_index]:" << average_raws[bucket_index]
                 << " average_dequants[bucket_index]:" << average_dequants[bucket_index]
                 << " var_raws[bucket_index]:" << var_raws[bucket_index]
                 << " var_dequants[bucket_index]:" << var_dequants[bucket_index]
                 << " varCorr:" << quant_param.varCorr << " meanCorr:" << quant_param.meanCorr;
  }
  return RET_OK;
}
236
CalWeightQuantBiasPerLayer(const float * raw_datas,size_t elem_count,const std::vector<float> & dequant_datas,std::vector<schema::QuantParamT> * quant_params)237 int CalWeightQuantBiasPerLayer(const float *raw_datas, size_t elem_count, const std::vector<float> &dequant_datas,
238 std::vector<schema::QuantParamT> *quant_params) {
239 CHECK_NULL_RETURN(raw_datas);
240 CHECK_NULL_RETURN(quant_params);
241 MS_CHECK_GT(elem_count, 0, RET_ERROR);
242 double total_raws = 0;
243 double total_dequants = 0;
244 double average_raws = 0;
245 double average_dequants = 0;
246 double var_raws = 0;
247 double var_dequants = 0;
248
249 for (size_t data_index = 0; data_index < elem_count; data_index++) {
250 auto data = raw_datas[data_index];
251 auto dequant_data = dequant_datas[data_index];
252 total_raws += data;
253 total_dequants += dequant_data;
254 }
255
256 average_raws = total_raws / elem_count;
257 average_dequants = total_dequants / elem_count;
258
259 constexpr int pow_exponent = 2;
260 for (size_t data_index = 0; data_index < elem_count; data_index++) {
261 var_raws += std::pow(raw_datas[data_index] - average_raws, pow_exponent);
262 var_dequants += std::pow(dequant_datas[data_index] - average_dequants, pow_exponent);
263 }
264
265 var_raws = std::sqrt(var_raws / elem_count);
266 var_dequants = std::sqrt(var_dequants / elem_count);
267
268 quant_params->at(0).varCorr = 1;
269 if (fabs(var_raws) > DBL_EPSILON && fabs(var_dequants) > DBL_EPSILON) {
270 auto temp_var_corr = var_raws / var_dequants;
271 const int min_var_corr = 0;
272 const int max_var_corr = 10;
273 if (temp_var_corr > min_var_corr && temp_var_corr < max_var_corr) {
274 quant_params->at(0).varCorr = temp_var_corr;
275 } else {
276 MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
277 }
278 }
279 quant_params->at(0).meanCorr = average_raws - average_dequants * quant_params->at(0).varCorr;
280 MS_LOG(INFO) << " average_raws:" << average_raws << " average_dequants:" << average_dequants
281 << " var_raws:" << var_dequants << " var_dequants:" << var_dequants
282 << " varCorr:" << quant_params->at(0).varCorr << " meanCorr:" << quant_params->at(0).meanCorr;
283 return RET_OK;
284 }
285 } // namespace lite
286 } // namespace mindspore
287