1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/quantizer/mixed_bit_weight_quantization.h"
18 #include <cmath>
19 #include <cfloat>
20 #include <map>
21 #include <memory>
22 #include "tools/common/statistic_utils.h"
23 #include "tools/converter/quantizer/quantize_util.h"
24
namespace mindspore::lite::quant {
// Named constant for the 24.0 factor in GetDx's tolerance scaling term
// (target_search_tolerance_ * sqrt(24 / element_num)); presumably an
// empirically tuned value — confirm against the quantizer design notes.
constexpr float kTwentyFour = 24.0f;
27
CalculateBiasCorrection(const float * weights,int element_num,float scale,float * origin_dequant_datas)28 void MixedBitWeightQuantization::CalculateBiasCorrection(const float *weights, int element_num, float scale,
29 float *origin_dequant_datas) {
30 MS_ASSERT(weights != nullptr);
31 MS_ASSERT(origin_dequant_datas != nullptr);
32 MS_ASSERT(element_num > 0);
33 double average_dequant = 0;
34 double average_raw = 0;
35 const float upround_offset = 0.5;
36 for (int i = 0; i < element_num; i++) {
37 float dequant = scale * (floorf(weights[i] / scale + upround_offset));
38 origin_dequant_datas[i] = dequant;
39 average_raw += weights[i];
40 average_dequant += dequant;
41 }
42
43 // mean
44 average_dequant = average_dequant / element_num;
45 average_raw = average_raw / element_num;
46 // std
47 double variance_dequant = 0;
48 double variance_raw = 0;
49 const int exponent = 2;
50 for (int i = 0; i < element_num; i++) {
51 variance_dequant += std::pow(origin_dequant_datas[i] - average_dequant, exponent);
52 variance_raw += std::pow(weights[i] - average_raw, exponent);
53 }
54 MS_ASSERT(variance_dequant >= 0);
55 MS_ASSERT(variance_raw >= 0);
56 variance_dequant = std::sqrt(variance_dequant / element_num);
57 variance_raw = std::sqrt(variance_raw / element_num);
58 if (fabs(variance_dequant) < DBL_EPSILON) {
59 var_corr_ = 1;
60 } else {
61 var_corr_ = variance_raw / variance_dequant;
62 }
63 mean_corr_ = average_raw - average_dequant * var_corr_;
64 }
65
66 // the error is currently measured per channel.
CalculateMeanError(std::vector<float> norms2,std::vector<float> dnorms2)67 float MixedBitWeightQuantization::CalculateMeanError(std::vector<float> norms2, std::vector<float> dnorms2) {
68 int error_count = 0;
69 float mse_error = 1e-10f;
70 const float soft = 1e-7f;
71 const float tolerance_error = 1.0e-10f;
72 for (size_t i = 0; i < norms2.size(); i++) {
73 if (norms2[i] < tolerance_error) {
74 continue;
75 }
76 error_count += 1;
77 mse_error += sqrtf(dnorms2[i] / norms2[i]);
78 }
79 auto mean_error = mse_error / (error_count + soft);
80 return mean_error;
81 }
82
83 // the `preferred` dim should point to the output channels dimension.
MeasureQuantizationError(float * weights,const int * shape,int dims,int preferred_dim,float scale)84 float MixedBitWeightQuantization::MeasureQuantizationError(float *weights, const int *shape, int dims,
85 int preferred_dim, float scale) {
86 MS_ASSERT(weights != nullptr);
87 MS_ASSERT(shape != nullptr);
88 // Init
89 int element_num = 1;
90 for (int i = 0; i < dims; i++) {
91 element_num *= shape[i];
92 }
93 if (element_num <= 0) {
94 MS_LOG(ERROR) << "Element is less than or equal to 0.";
95 return FLT_MAX;
96 }
97 int bucket_count = shape[preferred_dim];
98 std::vector<float> norms2(bucket_count);
99 std::vector<float> dnorms2(bucket_count);
100 const float init_number = 0.0;
101 for (int i = 0; i < bucket_count; i++) {
102 norms2[i] = init_number;
103 dnorms2[i] = init_number;
104 }
105
106 // Bucketing
107 std::vector<float> origin_dequant_datas(element_num);
108 std::vector<float> corr_dequant_datas(element_num);
109 int bucket_volume = 1;
110 for (int i = preferred_dim; i < dims; i++) {
111 bucket_volume *= shape[i];
112 }
113 MS_ASSERT(bucket_volume != 0);
114 const float upround_offset = 0.5;
115 // Bias Correction
116 CalculateBiasCorrection(weights, element_num, scale, origin_dequant_datas.data());
117 for (int i = 0; i < element_num; i++) {
118 int bucket = (i / bucket_volume) % bucket_count;
119 norms2[bucket] += weights[i] * weights[i];
120 float dequant = var_corr_ * (scale * (floorf(weights[i] / scale + upround_offset))) + mean_corr_;
121 corr_dequant_datas[i] = dequant;
122 float d = weights[i] - dequant;
123 dnorms2[bucket] += d * d;
124 }
125 auto mean_error = CalculateMeanError(norms2, dnorms2);
126 return mean_error;
127 }
128
CalculateLayerParams(const float * weights,int element_num)129 LayerParam MixedBitWeightQuantization::CalculateLayerParams(const float *weights, int element_num) {
130 MS_ASSERT(weights != nullptr);
131 float temp_norm_tot = 0.0;
132 for (int i = 0; i < element_num; i++) {
133 temp_norm_tot += weights[i] * weights[i];
134 }
135
136 LayerParam ret = {std::sqrt(1.0f / temp_norm_tot), GetMinMax(weights, element_num)};
137 return ret;
138 }
139
GetMinMax(const float * arr,int arrc)140 MinMax MixedBitWeightQuantization::GetMinMax(const float *arr, int arrc) {
141 MS_ASSERT(arr != nullptr);
142 MinMax min_max = {INFINITY, -INFINITY};
143 for (int i = 0; i < arrc; i++)
144 if (arr[i] > min_max.max)
145 min_max.max = arr[i];
146 else if (arr[i] < min_max.min)
147 min_max.min = arr[i];
148 return min_max;
149 }
150
// Searches for a quantization scale whose measured relative error is within
// rel_tol of target_err. Returns {status, scale}; status is RET_OK on
// success and RET_ERROR if max_iters were exhausted while the error was
// still above the target.
BinarySearchResult MixedBitWeightQuantization::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
                                                                                int preferred_dim, int max_iters,
                                                                                float target_err, float rel_tol) {
  MS_ASSERT(weights != nullptr);
  MS_ASSERT(shape != nullptr);
  int element_num = 1;
  for (int i = 0; i < dims; i++) {
    element_num *= shape[i];
  }
  MinMax mm = GetMinMax(weights, element_num);
  // Degenerate (near-constant) tensor: return a magnitude-proportional scale
  // immediately with status 0.
  if (mm.max < mm.min + 1.0e-5) {
    return {0, static_cast<float>(std::fabs(mm.max) + 1.0e-5)};
  }
  // start a binary search
  float curr_scale = (mm.max - mm.min) * target_err;
  // Grow the right bracket until its error exceeds the target, and shrink the
  // left bracket until its error drops below the target, so the target error
  // is bracketed. NOTE(review): neither loop has an iteration cap; if the
  // error is not monotone in the scale these loops could run for a long time.
  float right_hs_dx = curr_scale * kBinarySearchStep;
  while (MeasureQuantizationError(weights, shape, dims, preferred_dim, right_hs_dx) < target_err) {
    right_hs_dx *= kBinarySearchStep;
  }
  float left_hs_dx = curr_scale / kBinarySearchStep;
  while (MeasureQuantizationError(weights, shape, dims, preferred_dim, left_hs_dx) > target_err) {
    left_hs_dx /= kBinarySearchStep;
  }
  int iter_count = 0;
  BinarySearchResult res = {0, curr_scale};
  while (true) {
    float curr_err = MeasureQuantizationError(weights, shape, dims, preferred_dim, res.scale);
    // Converged: relative deviation from the target error is within tolerance.
    if (std::fabs(curr_err - target_err) / target_err < rel_tol) {
      return res;
    }
    if (iter_count > max_iters) {
      // Iteration budget exhausted: accept only if at least under the target.
      if (curr_err < target_err) {
        res.status = RET_OK;
      } else {
        res.status = RET_ERROR;
      }
      return res;
    }
    // Tighten the bracket around the target and take the midpoint.
    // NOTE(review): dividing by kBinarySearchStep is a true bisection only if
    // kBinarySearchStep == 2 — confirm against its definition.
    if (curr_err > target_err)
      right_hs_dx = res.scale;
    else
      left_hs_dx = res.scale;
    res.scale = (left_hs_dx + right_hs_dx) / kBinarySearchStep;
    iter_count += 1;
  }
}
197
GetDx(const float * weights,const int * shape,int dims,const std::string & node_name)198 float MixedBitWeightQuantization::GetDx(const float *weights, const int *shape, int dims,
199 const std::string &node_name) {
200 MS_ASSERT(weights != nullptr);
201 MS_ASSERT(shape != nullptr);
202 static std::map<std::string, LayerParam> param_map;
203
204 int element_num = 1;
205 for (int i = 0; i < dims; i++) {
206 element_num *= shape[i];
207 }
208
209 LayerParam params;
210 auto params_it = param_map.find(node_name);
211 if (params_it == param_map.end()) {
212 params = CalculateLayerParams(weights, element_num);
213 param_map.insert({node_name, params});
214 } else {
215 params = params_it->second;
216 }
217 return (target_relative_err_ + target_search_tolerance_ * std::sqrt(kTwentyFour / element_num)) / params.inv_norm;
218 }
219
DoQuantization(float * weights,std::vector<int64_t> shape,int preferred_dim,std::vector<schema::QuantParamT> * quant_params,std::vector<int16_t> * quant_datas,const std::string & node_name,bool use_auto_tune_alg)220 int MixedBitWeightQuantization::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
221 std::vector<schema::QuantParamT> *quant_params,
222 std::vector<int16_t> *quant_datas, const std::string &node_name,
223 bool use_auto_tune_alg) {
224 CHECK_NULL_RETURN(weights);
225 CHECK_NULL_RETURN(quant_params);
226 CHECK_NULL_RETURN(quant_datas);
227 int weight_count = 1;
228 int dims = shape.size();
229 int input_shape[4] = {0, 0, 0, 0};
230 MS_ASSERT(dims <= input_shape.size());
231 for (int i = 0; i < dims; i++) {
232 weight_count *= shape[i];
233 input_shape[i] = shape[i];
234 }
235
236 float scale = 1.0;
237 if (use_auto_tune_alg) {
238 scale = GetDx(weights, input_shape, dims, node_name);
239 } else {
240 BinarySearchResult br = BinarySearchForQuantizationScale(
241 weights, input_shape, dims, preferred_dim, max_search_iters_, target_relative_err_, target_search_tolerance_);
242 if (br.status != RET_OK) {
243 MS_LOG(WARNING) << "this layer reached max iters.";
244 return RET_NO_CHANGE;
245 }
246 scale = br.scale;
247 }
248
249 schema::QuantParamT quant_param;
250 int qr = QuantizeByScale(weights, weight_count, scale, &quant_param, quant_datas);
251 if (qr != RET_OK) {
252 MS_LOG(ERROR) << "quant failed.";
253 return RET_ERROR;
254 }
255 quant_params->push_back(quant_param);
256 return RET_OK;
257 }
258
QuantizeByScale(const float * weights,int weightsc,float scale,schema::QuantParamT * quant_params,std::vector<int16_t> * quant_datas)259 int MixedBitWeightQuantization::QuantizeByScale(const float *weights, int weightsc, float scale,
260 schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
261 CHECK_NULL_RETURN(weights);
262 CHECK_NULL_RETURN(quant_params);
263 CHECK_NULL_RETURN(quant_datas);
264 MS_CHECK_GE(static_cast<int>(quant_datas->size()), weightsc, RET_ERROR);
265 const float upround_offset = 0.5;
266 for (int i = 0; i < weightsc; i++) {
267 auto q = static_cast<int>(floorf(weights[i] / scale + upround_offset));
268 quant_datas->at(i) = q;
269 }
270 quant_params->meanCorr = mean_corr_;
271 quant_params->varCorr = var_corr_;
272 quant_params->scale = scale;
273 quant_params->zeroPoint = 0;
274 quant_params->numBits = 0;
275 return RET_OK;
276 }
277
QuantFilter(const PrimitivePtr & primitive,const AnfNodePtr & parameter_node,const tensor::TensorPtr & weight,QuantType quant_type,bool use_auto_tune_alg)278 int MixedBitWeightQuantization::QuantFilter(const PrimitivePtr &primitive, const AnfNodePtr ¶meter_node,
279 const tensor::TensorPtr &weight, QuantType quant_type,
280 bool use_auto_tune_alg) {
281 CHECK_NULL_RETURN(primitive);
282 CHECK_NULL_RETURN(weight);
283 std::vector<schema::QuantParamT> quant_params;
284 int elem_count = weight->DataSize();
285 auto *raw_data = static_cast<float *>(weight->data_c());
286 if (raw_data == nullptr) {
287 MS_LOG(ERROR) << "rawDatas is nullptr";
288 return RET_ERROR;
289 }
290
291 std::vector<int16_t> quant_data(elem_count);
292 auto ret = DoQuantization(static_cast<float *>(weight->data_c()), weight->shape_c(), 0, &quant_params, &quant_data,
293 parameter_node->fullname_with_scope(), use_auto_tune_alg);
294 if (ret != RET_OK) {
295 return ret;
296 }
297 ret = UpdateTensorDataAndSize(parameter_node, weight, quant_data.data(), quant_data.size() * sizeof(int16_t),
298 kNumberTypeInt16);
299 if (ret != RET_OK) {
300 MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
301 return RET_ERROR;
302 }
303 auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(quant_params);
304 CHECK_NULL_RETURN(quantization_ptr);
305 weight->set_quant_param(std::vector<std::shared_ptr<mindspore::QuantizationParam>>{quantization_ptr});
306 auto quant_type_value = MakeValue(static_cast<int>(quant_type));
307 MS_CHECK_TRUE_MSG(quant_type_value != nullptr, RET_ERROR, "quant_type is nullptr.");
308 primitive->AddAttr(quant::kQuantType, quant_type_value);
309 return ret;
310 }
311 } // namespace mindspore::lite::quant
312