/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/quantizer/mixed_bit_weight_quantizer.h"
#include <cmath>
#include <vector>

20 namespace mindspore::lite::quant {
21 // the error is currently measured per channel.
22 // it could be measured per layer but it would be less good.
23 // the `preferred` dim should point to the output channels dimension.
MeasureQuantizationError(float * weights,const int * shape,int dims,int preferred_dim,float scale)24 float MixedBitWeightQuantizer::MeasureQuantizationError(float *weights, const int *shape, int dims, int preferred_dim,
25 float scale) {
26 MS_ASSERT(weights != nullptr);
27 MS_ASSERT(shape != nullptr);
28 int numel = 1;
29 for (int i = 0; i < dims; i++) {
30 numel *= shape[i];
31 }
32 int bucket_count = shape[preferred_dim];
33 std::vector<float> norms2(bucket_count);
34 std::vector<float> dnorms2(bucket_count);
35 for (int i = 0; i < bucket_count; i++) {
36 norms2[i] = 0.0;
37 dnorms2[i] = 0.0;
38 }
39 double average_dequant = 0;
40 double average_raw = 0;
41 std::vector<float> dequant_datas(numel);
42 int bucket_volume = 1;
43 for (int i = preferred_dim; i < dims; i++) {
44 bucket_volume *= shape[i];
45 }
46 for (int i = 0; i < numel; i++) {
47 float dequant = scale * (floorf(weights[i] / scale + 0.5));
48 dequant_datas[i] = dequant;
49 average_raw += weights[i];
50 average_dequant += dequant;
51 }
52 // mean
53 average_dequant = average_dequant / numel;
54 average_raw = average_raw / numel;
55 // std
56 double variance_dequant = 0;
57 double variance_raw = 0;
58 for (int i = 0; i < numel; i++) {
59 variance_dequant += std::pow(dequant_datas[i] - average_dequant, 2);
60 variance_raw += std::pow(weights[i] - average_raw, 2);
61 }
62 variance_dequant = std::sqrt(variance_dequant / numel);
63 variance_raw = std::sqrt(variance_raw / numel);
64 var_corr_ = variance_raw / variance_dequant;
65 mean_corr_ = average_raw - average_dequant * var_corr_;
66
67 for (int i = 0; i < numel; i++) {
68 int bucket = (i / bucket_volume) % bucket_count;
69 norms2[bucket] += weights[i] * weights[i];
70 float dequant = var_corr_ * (scale * (floorf(weights[i] / scale + 0.5))) + mean_corr_;
71 float d = weights[i] - dequant;
72 dnorms2[bucket] += d * d;
73 }
74
75 int c = 0;
76 float t = 1e-10;
77 for (int i = 0; i < bucket_count; i++) {
78 if (norms2[i] < 1.0e-10) continue;
79 c += 1;
80 t += sqrtf(dnorms2[i] / norms2[i]);
81 }
82 return t / (c + 1e-7);
83 }
84
GetMinMax(const float * arr,int arrc)85 MinMax MixedBitWeightQuantizer::GetMinMax(const float *arr, int arrc) {
86 MS_ASSERT(arr != nullptr);
87 MinMax min_max = {INFINITY, -INFINITY};
88 for (int i = 0; i < arrc; i++)
89 if (arr[i] > min_max.max)
90 min_max.max = arr[i];
91 else if (arr[i] < min_max.min)
92 min_max.min = arr[i];
93 return min_max;
94 }
95
BinarySearchForQuantizationScale(float * weights,int * shape,int dims,int preferred_dim,int max_iters,float target_err,float rel_tol)96 BinarySearchResult MixedBitWeightQuantizer::BinarySearchForQuantizationScale(float *weights, int *shape, int dims,
97 int preferred_dim, int max_iters,
98 float target_err, float rel_tol) {
99 MS_ASSERT(weights != nullptr);
100 MS_ASSERT(shape != nullptr);
101 int element_num = 1;
102 for (int i = 0; i < dims; i++) {
103 element_num *= shape[i];
104 }
105 MinMax mm = GetMinMax(weights, element_num);
106 if (mm.max < mm.min + 1.0e-5) {
107 return {0, static_cast<float>(std::fabs(mm.max) + 1.0e-5)};
108 }
109 // start a binary search
110 float curr_scale = (mm.max - mm.min) * target_err;
111 float right_hs_dx = curr_scale * 2.0;
112 while (MeasureQuantizationError(weights, shape, dims, preferred_dim, right_hs_dx) < target_err) {
113 right_hs_dx *= 2.0;
114 }
115 float left_hs_dx = curr_scale / 2.0;
116 while (MeasureQuantizationError(weights, shape, dims, preferred_dim, left_hs_dx) > target_err) {
117 left_hs_dx /= 2.0;
118 }
119 int iter_count = 0;
120 BinarySearchResult res = {0, curr_scale};
121 while (true) {
122 float curr_err = MeasureQuantizationError(weights, shape, dims, preferred_dim, res.scale);
123 if (std::fabs(curr_err - target_err) / target_err < rel_tol) {
124 return res;
125 }
126 if (iter_count > max_iters) {
127 res.status = 1;
128 return res;
129 }
130 if (curr_err > target_err)
131 right_hs_dx = res.scale;
132 else
133 left_hs_dx = res.scale;
134 res.scale = (left_hs_dx + right_hs_dx) / 2.0;
135 iter_count += 1;
136 }
137 }
138
DoQuantization(float * weights,std::vector<int64_t> shape,int preferred_dim,std::vector<schema::QuantParamT> * quant_params,std::vector<int16_t> * quant_datas)139 int MixedBitWeightQuantizer::DoQuantization(float *weights, std::vector<int64_t> shape, int preferred_dim,
140 std::vector<schema::QuantParamT> *quant_params,
141 std::vector<int16_t> *quant_datas) {
142 MS_ASSERT(weights != nullptr);
143 int weight_count = 1;
144 int dims = shape.size();
145 int input_shape[4] = {0, 0, 0, 0};
146 for (int i = 0; i < dims; i++) {
147 weight_count *= shape[i];
148 input_shape[i] = shape[i];
149 }
150
151 BinarySearchResult br = BinarySearchForQuantizationScale(weights, input_shape, dims, preferred_dim, max_search_iters_,
152 target_relative_err_, target_search_tolerance_);
153 if (br.status != RET_OK) {
154 MS_LOG(ERROR) << "reached_max_iters";
155 return RET_ERROR;
156 }
157 schema::QuantParamT quant_param;
158 int qr = QuantizeByScale(weights, weight_count, br.scale, &quant_param, quant_datas);
159 if (qr != RET_OK) {
160 MS_LOG(ERROR) << "quant failed.";
161 return RET_ERROR;
162 }
163
164 // It is used to calculate the Shannon entropy.
165 quant_params->push_back(quant_param);
166 return RET_OK;
167 }
168
QuantizeByScale(const float * weights,int weightsc,float scale,schema::QuantParamT * quant_params,std::vector<int16_t> * quant_datas)169 int MixedBitWeightQuantizer::QuantizeByScale(const float *weights, int weightsc, float scale,
170 schema::QuantParamT *quant_params, std::vector<int16_t> *quant_datas) {
171 MS_ASSERT(weights != nullptr);
172 for (int i = 0; i < weightsc; i++) {
173 auto q = static_cast<int>(floorf(weights[i] / scale + 0.5));
174 quant_datas->at(i) = q;
175 }
176 quant_params->meanCorr = mean_corr_;
177 quant_params->varCorr = var_corr_;
178 quant_params->scale = scale;
179 quant_params->zeroPoint = 0;
180 quant_params->numBits = 0;
181 quant_params->inited = true;
182 return RET_OK;
183 }
}  // namespace mindspore::lite::quant