/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/quantizer/mixed_bit_weight_quantization.h"
#include <cmath>
#include <cfloat>
#include <map>
#include <memory>
#include "tools/common/statistic_utils.h"
#include "tools/converter/quantizer/quantize_util.h"

namespace mindspore::lite::quant {
constexpr float kTwentyFour = 24.0f;

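// Derives the bias-correction factors: the weights are rounded to the `scale` grid, and var_corr_ /
// mean_corr_ are chosen so that the corrected dequantized values (var_corr_ * dequant + mean_corr_)
// match the mean and standard deviation of the raw weights.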
void MixedBitWeightQuantization::CalculateBiasCorrection(const float *weights, int element_num, float scale,
                                                         float *origin_dequant_datas) {
  MS_ASSERT(weights != nullptr);
  MS_ASSERT(origin_dequant_datas != nullptr);
  MS_ASSERT(element_num > 0);
  double average_dequant = 0;
  double average_raw = 0;
  const float upround_offset = 0.5;
  for (int i = 0; i < element_num; i++) {
    float dequant = scale * (floorf(weights[i] / scale + upround_offset));
    origin_dequant_datas[i] = dequant;
    average_raw += weights[i];
    average_dequant += dequant;
  }

  // mean
  average_dequant = average_dequant / element_num;
  average_raw = average_raw / element_num;
  // std
  double variance_dequant = 0;
  double variance_raw = 0;
  const int exponent = 2;
  for (int i = 0; i < element_num; i++) {
    variance_dequant += std::pow(origin_dequant_datas[i] - average_dequant, exponent);
    variance_raw += std::pow(weights[i] - average_raw, exponent);
  }
  MS_ASSERT(variance_dequant >= 0);
  MS_ASSERT(variance_raw >= 0);
  variance_dequant = std::sqrt(variance_dequant / element_num);
  variance_raw = std::sqrt(variance_raw / element_num);
  if (fabs(variance_dequant) < DBL_EPSILON) {
    var_corr_ = 1;
  } else {
    var_corr_ = variance_raw / variance_dequant;
  }
  mean_corr_ = average_raw - average_dequant * var_corr_;
}

// The error is currently measured per channel.
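// For each bucket c it accumulates sqrt(||w_c - w_hat_c||^2 / ||w_c||^2), the relative L2 error of the
// dequantized weights, and returns the mean over all buckets whose raw norm is above tolerance_error.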
float MixedBitWeightQuantization::CalculateMeanError(std::vector<float> norms2, std::vector<float> dnorms2) {
  int error_count = 0;
  float mse_error = 1e-10f;
  const float soft = 1e-7f;
  const float tolerance_error = 1.0e-10f;
  for (size_t i = 0; i < norms2.size(); i++) {
    if (norms2[i] < tolerance_error) {
      continue;
    }
    error_count += 1;
    mse_error += sqrtf(dnorms2[i] / norms2[i]);
  }
  auto mean_error = mse_error / (error_count + soft);
  return mean_error;
}

// The `preferred` dim should point to the output channels dimension.
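// Elements are grouped into bucket_count = shape[preferred_dim] buckets; in row-major layout the bucket of
// element i is (i / bucket_volume) % bucket_count, where bucket_volume is the product of the dimensions
// after preferred_dim. E.g. for shape {8, 3, 3, 16} and preferred_dim = 0, bucket_volume = 3 * 3 * 16 = 144.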
float MixedBitWeightQuantization::MeasureQuantizationError(float *weights, const int *shape, int dims,
                                                           int preferred_dim, float scale) {
  MS_ASSERT(weights != nullptr);
  MS_ASSERT(shape != nullptr);
  // Init
  int element_num = 1;
  for (int i = 0; i < dims; i++) {
    element_num *= shape[i];
  }
  if (element_num <= 0) {
    MS_LOG(ERROR) << "element_num is less than or equal to 0.";
    return FLT_MAX;
  }
  int bucket_count = shape[preferred_dim];
  std::vector<float> norms2(bucket_count);
  std::vector<float> dnorms2(bucket_count);
  const float init_number = 0.0;
  for (int i = 0; i < bucket_count; i++) {
    norms2[i] = init_number;
    dnorms2[i] = init_number;
  }

  // Bucketing
  std::vector<float> origin_dequant_datas(element_num);
  std::vector<float> corr_dequant_datas(element_num);
  int bucket_volume = 1;
  for (int i = preferred_dim + 1; i < dims; i++) {
    bucket_volume *= shape[i];
  }
  MS_ASSERT(bucket_volume != 0);
  const float upround_offset = 0.5;
  // Bias Correction
  CalculateBiasCorrection(weights, element_num, scale, origin_dequant_datas.data());
  for (int i = 0; i < element_num; i++) {
    int bucket = (i / bucket_volume) % bucket_count;
    norms2[bucket] += weights[i] * weights[i];
    float dequant = var_corr_ * (scale * (floorf(weights[i] / scale + upround_offset))) + mean_corr_;
    corr_dequant_datas[i] = dequant;
    float d = weights[i] - dequant;
    dnorms2[bucket] += d * d;
  }
  auto mean_error = CalculateMeanError(norms2, dnorms2);
  return mean_error;
}

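// Computes the per-layer statistics used by the auto-tune path: inv_norm = sqrt(1 / ||W||^2) = 1 / ||W||_2
// and the min/max range of the weights.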
LayerParam MixedBitWeightQuantization::CalculateLayerParams(const float *weights, int element_num) {
  MS_ASSERT(weights != nullptr);
  float temp_norm_tot = 0.0;
  for (int i = 0; i < element_num; i++) {
    temp_norm_tot += weights[i] * weights[i];
  }

  LayerParam ret = {std::sqrt(1.0f / temp_norm_tot), GetMinMax(weights, element_num)};
  return ret;
}

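// Single pass over the buffer. Min and max are tracked with independent checks so that an element which
// raises the max (e.g. the first one) can still be considered for the min.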
MinMax MixedBitWeightQuantization::GetMinMax(const float *arr, int arrc) {
  MS_ASSERT(arr != nullptr);
  MinMax min_max = {INFINITY, -INFINITY};
  for (int i = 0; i < arrc; i++) {
    if (arr[i] > min_max.max) {
      min_max.max = arr[i];
    }
    if (arr[i] < min_max.min) {
      min_max.min = arr[i];
    }
  }
  return min_max;
}

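// Finds a scale whose measured quantization error is close to target_err. The bracket
// [left_hs_dx, right_hs_dx] is first widened geometrically by kBinarySearchStep until it straddles
// target_err, then narrowed by bisection (assuming kBinarySearchStep == 2, the update below is the
// midpoint) until the relative deviation drops below rel_tol or max_iters is exceeded.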
BinarySearchResult MixedBitWeightQuantization::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
                                                                                int preferred_dim, int max_iters,
                                                                                float target_err, float rel_tol) {
  MS_ASSERT(weights != nullptr);
  MS_ASSERT(shape != nullptr);
  int element_num = 1;
  for (int i = 0; i < dims; i++) {
    element_num *= shape[i];
  }
  MinMax mm = GetMinMax(weights, element_num);
  if (mm.max < mm.min + 1.0e-5) {
    return {0, static_cast<float>(std::fabs(mm.max) + 1.0e-5)};
  }
  // start a binary search
  float curr_scale = (mm.max - mm.min) * target_err;
  float right_hs_dx = curr_scale * kBinarySearchStep;
  while (MeasureQuantizationError(weights, shape, dims, preferred_dim, right_hs_dx) < target_err) {
    right_hs_dx *= kBinarySearchStep;
  }
  float left_hs_dx = curr_scale / kBinarySearchStep;
  while (MeasureQuantizationError(weights, shape, dims, preferred_dim, left_hs_dx) > target_err) {
    left_hs_dx /= kBinarySearchStep;
  }
  int iter_count = 0;
  BinarySearchResult res = {0, curr_scale};
  while (true) {
    float curr_err = MeasureQuantizationError(weights, shape, dims, preferred_dim, res.scale);
    if (std::fabs(curr_err - target_err) / target_err < rel_tol) {
      return res;
    }
    if (iter_count > max_iters) {
      if (curr_err < target_err) {
        res.status = RET_OK;
      } else {
        res.status = RET_ERROR;
      }
      return res;
    }
    if (curr_err > target_err) {
      right_hs_dx = res.scale;
    } else {
      left_hs_dx = res.scale;
    }
    res.scale = (left_hs_dx + right_hs_dx) / kBinarySearchStep;
    iter_count += 1;
  }
}

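// Auto-tune scale estimate: returns (target_relative_err_ + target_search_tolerance_ * sqrt(24 / N)) *
// ||W||_2, where N is the element count; the sqrt(24 / N) term tightens the tolerance as tensors grow.
// Layer statistics are cached in a function-local static map keyed by node name (not thread-safe).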
float MixedBitWeightQuantization::GetDx(const float *weights, const int *shape, int dims,
                                        const std::string &node_name) {
  MS_ASSERT(weights != nullptr);
  MS_ASSERT(shape != nullptr);
  static std::map<std::string, LayerParam> param_map;

  int element_num = 1;
  for (int i = 0; i < dims; i++) {
    element_num *= shape[i];
  }

  LayerParam params;
  auto params_it = param_map.find(node_name);
  if (params_it == param_map.end()) {
    params = CalculateLayerParams(weights, element_num);
    param_map.insert({node_name, params});
  } else {
    params = params_it->second;
  }
  return (target_relative_err_ + target_search_tolerance_ * std::sqrt(kTwentyFour / element_num)) / params.inv_norm;
}

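// Quantizes one weight tensor: the scale comes either from the auto-tune heuristic (GetDx) or from the
// binary search, and the weights are then rounded to int16 by QuantizeByScale. Returns RET_NO_CHANGE when
// the search does not converge, leaving the layer for the caller to handle.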
int MixedBitWeightQuantization::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
                                               std::vector<schema::QuantParamT> *quant_params,
                                               std::vector<int16_t> *quant_datas, const std::string &node_name,
                                               bool use_auto_tune_alg) {
  CHECK_NULL_RETURN(weights);
  CHECK_NULL_RETURN(quant_params);
  CHECK_NULL_RETURN(quant_datas);
  int weight_count = 1;
  int dims = shape.size();
  int input_shape[4] = {0, 0, 0, 0};
  MS_ASSERT(dims <= static_cast<int>(sizeof(input_shape) / sizeof(input_shape[0])));
  for (int i = 0; i < dims; i++) {
    weight_count *= shape[i];
    input_shape[i] = static_cast<int>(shape[i]);
  }

  float scale = 1.0;
  if (use_auto_tune_alg) {
    scale = GetDx(weights, input_shape, dims, node_name);
  } else {
    BinarySearchResult br = BinarySearchForQuantizationScale(
      weights, input_shape, dims, preferred_dim, max_search_iters_, target_relative_err_, target_search_tolerance_);
    if (br.status != RET_OK) {
      MS_LOG(WARNING) << "Binary search for the quantization scale of this layer reached max iterations.";
      return RET_NO_CHANGE;
    }
    scale = br.scale;
  }

  schema::QuantParamT quant_param;
  int qr = QuantizeByScale(weights, weight_count, scale, &quant_param, quant_datas);
  if (qr != RET_OK) {
    MS_LOG(ERROR) << "QuantizeByScale failed.";
    return RET_ERROR;
  }
  quant_params->push_back(quant_param);
  return RET_OK;
}

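// Rounds each weight to the nearest multiple of `scale` (floorf(x / scale + 0.5)) and records the scale
// together with the mean/variance correction factors. numBits is set to 0, which presumably marks the
// tensor as mixed-bit rather than a fixed bit width.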
int MixedBitWeightQuantization::QuantizeByScale(const float *weights, int weightsc, float scale,
                                                schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
  CHECK_NULL_RETURN(weights);
  CHECK_NULL_RETURN(quant_params);
  CHECK_NULL_RETURN(quant_datas);
  MS_CHECK_GE(static_cast<int>(quant_datas->size()), weightsc, RET_ERROR);
  const float upround_offset = 0.5;
  for (int i = 0; i < weightsc; i++) {
    auto q = static_cast<int>(floorf(weights[i] / scale + upround_offset));
    quant_datas->at(i) = q;
  }
  quant_params->meanCorr = mean_corr_;
  quant_params->varCorr = var_corr_;
  quant_params->scale = scale;
  quant_params->zeroPoint = 0;
  quant_params->numBits = 0;
  return RET_OK;
}

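// Entry point: quantizes a parameter tensor, replaces its data with the int16 values, attaches the
// quantization parameters to the tensor, and tags the primitive with the quant type.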
int MixedBitWeightQuantization::QuantFilter(const PrimitivePtr &primitive, const AnfNodePtr &parameter_node,
                                            const tensor::TensorPtr &weight, QuantType quant_type,
                                            bool use_auto_tune_alg) {
  CHECK_NULL_RETURN(primitive);
  CHECK_NULL_RETURN(parameter_node);
  CHECK_NULL_RETURN(weight);
  std::vector<schema::QuantParamT> quant_params;
  int elem_count = weight->DataSize();
  auto *raw_data = static_cast<float *>(weight->data_c());
  if (raw_data == nullptr) {
    MS_LOG(ERROR) << "raw_data is nullptr.";
    return RET_ERROR;
  }

  std::vector<int16_t> quant_data(elem_count);
  auto ret = DoQuantization(raw_data, weight->shape_c(), 0, &quant_params, &quant_data,
                            parameter_node->fullname_with_scope(), use_auto_tune_alg);
  if (ret != RET_OK) {
    return ret;
  }
  ret = UpdateTensorDataAndSize(parameter_node, weight, quant_data.data(), quant_data.size() * sizeof(int16_t),
                                kNumberTypeInt16);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "UpdateTensorDataAndSize error";
    return RET_ERROR;
  }
  auto quantization_ptr = quant::ConvertQuantParamTToQuantizationParam(quant_params);
  CHECK_NULL_RETURN(quantization_ptr);
  weight->set_quant_param(std::vector<std::shared_ptr<mindspore::QuantizationParam>>{quantization_ptr});
  auto quant_type_value = MakeValue(static_cast<int>(quant_type));
  MS_CHECK_TRUE_MSG(quant_type_value != nullptr, RET_ERROR, "quant_type is nullptr.");
  primitive->AddAttr(quant::kQuantType, quant_type_value);
  return ret;
}
}  // namespace mindspore::lite::quant