1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
18 #define MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
19
20 #include <float.h>
21 #include <cmath>
22 #include <climits>
23 #include <limits>
24 #include <algorithm>
25 #include <vector>
26 #include "include/errorcode.h"
27 #include "src/common/log_adapter.h"
28 #include "ir/dtype/type_id.h"
29
30 namespace mindspore {
31
namespace schema {
// Forward declaration only; the quantization helpers below take QuantParamT by pointer/vector.
// NOTE(review): the templates in this header read QuantParamT members (scale, zeroPoint, numBits,
// narrowRange, varCorr, meanCorr, inited) and use schema::QuantType / QuantType_QUANT_WEIGHT, none
// of which this declaration provides. The complete schema definition presumably must be included
// before this header — confirm.
struct QuantParamT;
}  // namespace schema
35
36 namespace lite {
// Returned by DoPerChannelQuant when a QUANT_WEIGHT tensor's filters are too small for
// quantization to pay off; presumably tells the caller to leave the tensor unquantized.
constexpr int RET_QUANT_CONTINUE = 2;
// Quantization scales at or below this value are treated as degenerate (see QuantizeData).
constexpr double SCALE_THREASHOLD = 1e-38;

// NOTE(review): not referenced in this header; presumably marks per-tensor (single-parameter)
// quantization — confirm against its users.
constexpr int kPerTensor = 1;
41
QuantMax(int bits,TypeId type)42 inline int QuantMax(int bits, TypeId type) {
43 if (type == kNumberTypeInt8) {
44 return (1 << (bits - 1)) - 1;
45 } else if (type == kNumberTypeUInt8) {
46 return (1 << bits) - 1;
47 }
48 return 0;
49 }
50
QuantMin(int bits,TypeId type)51 inline int QuantMin(int bits, TypeId type) {
52 if (type == kNumberTypeInt8) {
53 return -(1 << (bits - 1));
54 }
55 return 0;
56 }
57
58 STATUS GetMaxMinPerChannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
59 bool channel_at_first, float *desired_max, float *desired_min);
60
61 STATUS CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, bool narrow_range,
62 int quant_max, int quant_min, int num_bits);
63
64 template <typename T>
QuantizeData(const float originData,const schema::QuantParamT * quantParam)65 T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
66 MS_ASSERT(quantParam != nullptr);
67 MS_ASSERT(quantParam->inited);
68 const auto scale = quantParam->scale;
69 const auto zeroPoint = quantParam->zeroPoint;
70 const auto numBit = quantParam->numBits;
71 const auto narrowRange = quantParam->narrowRange;
72 const int32_t quantMax = (1 << (unsigned int)(numBit - 1)) - 1;
73 const int32_t quantMin = -1 * (1 << (unsigned int)(numBit - 1)) + (narrowRange ? 1 : 0);
74 const double maxLimit = static_cast<float>(quantMax - zeroPoint) * scale;
75 const double minLimit = static_cast<float>(quantMin - zeroPoint) * scale;
76
77 return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
78 double tmp;
79 if (originData > maxLimit) {
80 tmp = maxLimit;
81 } else if (originData < minLimit) {
82 tmp = minLimit;
83 } else {
84 tmp = originData;
85 }
86 auto quantData = static_cast<T>(std::round(zeroPoint + tmp / scale));
87 return quantData;
88 }();
89 }
90
91 template <typename T>
QuantizeData(float originData,const schema::QuantParamT * quantParam,int quant_max,int quant_min)92 T QuantizeData(float originData, const schema::QuantParamT *quantParam, int quant_max, int quant_min) {
93 MS_ASSERT(quantParam != nullptr);
94 MS_ASSERT(quantParam->inited);
95 const auto scale = quantParam->scale;
96 const int zeroPoint = quantParam->zeroPoint;
97 const int maxLimit = quant_max;
98 const int minLimit = quant_min;
99
100 if (scale <= SCALE_THREASHOLD) {
101 return 0;
102 }
103
104 return [maxLimit, minLimit, zeroPoint, scale, originData] {
105 auto quant_data = std::round(originData / scale + zeroPoint);
106 if (quant_data > maxLimit) {
107 quant_data = maxLimit;
108 } else if (quant_data < minLimit) {
109 quant_data = minLimit;
110 }
111 return static_cast<T>(quant_data);
112 }();
113 }
114
115 template <typename T>
DoPerLayerQuant(const float * raw_datas,size_t elem_count,std::vector<schema::QuantParamT> * quant_params,const int & quant_max,const int & quant_min,const size_t & bit_num,const bool & k_means,std::vector<T> * quant_datas)116 STATUS DoPerLayerQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
117 const int &quant_max, const int &quant_min, const size_t &bit_num, const bool &k_means,
118 std::vector<T> *quant_datas) {
119 float min = FLT_MAX;
120 float max = -FLT_MIN;
121 for (uint32_t i = 0; i < elem_count; i++) {
122 min = std::min(min, raw_datas[i]);
123 max = std::max(max, raw_datas[i]);
124 }
125
126 schema::QuantParamT quant_param;
127 if (!k_means) {
128 STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
129 if (status != RET_OK) {
130 MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
131 return status;
132 }
133 }
134 quant_params->emplace_back(quant_param);
135 // update data and datatype
136 for (uint32_t i = 0; i < elem_count; i++) {
137 float raw_data = raw_datas[i];
138 if (!k_means) {
139 auto quant_data = QuantizeData<T>(raw_data, &quant_param, quant_max, quant_min);
140 (*quant_datas)[i] = quant_data;
141 }
142 }
143 return RET_OK;
144 }
145
146 template <typename T>
147 STATUS DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::QuantType &quant_type,
148 std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
149 const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas, int channels,
150 bool channel_at_first = true) {
151 static const int quant_param_size = 32 * 8;
152 std::vector<float> dequant_datas(quant_datas->size());
153 if (channels <= 0) {
154 MS_LOG(ERROR) << "channels must be greater than 0";
155 return RET_ERROR;
156 }
157 size_t one_filter_size = elem_count / channels;
158 bool do_quant = quant_param_size / (sizeof(float) * 8 - bit_num) < one_filter_size;
159 if (!do_quant && quant_type == schema::QuantType_QUANT_WEIGHT) {
160 MS_LOG(INFO) << "too few elements in a filter, no need to quantize. " << one_filter_size;
161 return RET_QUANT_CONTINUE;
162 }
163 for (int i = 0; i < channels; i++) {
164 float min = FLT_MAX;
165 float max = -FLT_MAX;
166 STATUS status =
167 GetMaxMinPerChannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min);
168 if (status != RET_OK) {
169 MS_LOG(ERROR) << "GetMaxMinPerChannel failed" << status;
170 return status;
171 }
172 schema::QuantParamT quant_param;
173 status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
174 if (status != RET_OK) {
175 MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
176 return status;
177 }
178 // do quantization
179 double average_dequant = 0;
180 double average_raw = 0;
181 for (uint32_t j = 0; j < one_filter_size; j++) {
182 auto index = j + i * one_filter_size;
183 if (!channel_at_first) {
184 index = j * channels + i;
185 }
186 MS_ASSERT(index < elem_count);
187 float raw_data = raw_datas[index];
188 auto quant_data = QuantizeData<T>(raw_data, &quant_param, quant_max, quant_min);
189 (*quant_datas)[index] = quant_data;
190
191 if (quant_type == schema::QuantType_QUANT_WEIGHT) {
192 float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint);
193 dequant_datas[index] = dequant_data;
194 average_dequant += dequant_data;
195 average_raw += raw_data;
196 }
197 }
198 if (quant_type == schema::QuantType_QUANT_WEIGHT && !k_means) {
199 // mean
200 average_dequant = average_dequant / one_filter_size;
201 average_raw = average_raw / one_filter_size;
202 // std
203 double variance_dequant = 0;
204 double variance_raw = 0;
205 for (uint32_t j = 0; j < one_filter_size; j++) {
206 auto index = j + i * one_filter_size;
207 if (!channel_at_first) {
208 index = j * channels + i;
209 }
210 MS_ASSERT(index < elem_count);
211 variance_dequant += std::pow(dequant_datas[index] - average_dequant, 2);
212 variance_raw += std::pow(raw_datas[index] - average_raw, 2);
213 }
214 variance_dequant = std::sqrt(variance_dequant / one_filter_size);
215 variance_raw = std::sqrt(variance_raw / one_filter_size);
216 quant_param.varCorr = 1;
217 if (variance_raw != 0 && variance_dequant != 0) {
218 auto temp_var_corr = variance_raw / variance_dequant;
219 if (temp_var_corr > 0 && temp_var_corr < 10) {
220 quant_param.varCorr = temp_var_corr;
221 } else {
222 MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
223 }
224 }
225 quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr;
226 }
227 quant_params->emplace_back(quant_param);
228 }
229 return RET_OK;
230 }
231 } // namespace lite
232 } // namespace mindspore
233
234 #endif // MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
235