1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
18 #define MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
19
20 #include <cfloat>
21 #include <cmath>
22 #include <climits>
23 #include <limits>
24 #include <algorithm>
25 #include <vector>
26 #include <numeric>
27 #include <functional>
28 #include <map>
29 #include "include/errorcode.h"
30 #include "src/common/log_adapter.h"
31 #include "src/common/log_util.h"
32 #include "ir/dtype/type_id.h"
33 #include "schema/inner/model_generated.h"
34 #include "tools/common/statistic_utils.h"
35
36 namespace mindspore {
37 namespace schema {
38 struct QuantParamT;
39 }
40
41 namespace lite {
/// Closed real-valued range [min, max]; GetAllChannelMinMax produces one of these per channel bucket.
struct MinMax {
  float min;  // lower bound of the range
  float max;  // upper bound of the range
};
46
47 static constexpr double SCALE_THREASHOLD = 1e-38;
48 static constexpr int kPerTensor = 1;
49
/**
 * Largest value representable by a `bits`-bit integer type.
 *
 * @param bits Bit width of the quantized type (the shift below must stay < 32, i.e. bits in
 *        [1, 32] when signed and [0, 31] when unsigned).
 * @param is_unsigned Whether the quantized type is unsigned.
 * @return 2^(bits-1) - 1 for signed types, 2^bits - 1 for unsigned types.
 */
inline int QuantMax(int bits, bool is_unsigned = false) {
  // Shift and subtract in unsigned arithmetic: with plain int, `(1 << 31) - 1` overflows for
  // signed 32-bit quantization, while (1U << 31) - 1U is exactly INT_MAX.
  const auto shift = static_cast<unsigned int>(is_unsigned ? bits : bits - 1);
  return static_cast<int>((1U << shift) - 1U);
}
57
/**
 * Smallest value of a `bits`-bit integer quantization range.
 *
 * @param bits Bit width of the quantized type (expected in [1, 32] when signed).
 * @param is_unsigned Whether the quantized type is unsigned; unsigned ranges start at 0.
 * @param is_narrow For signed types, drop the most negative value so the range is symmetric
 *        (e.g. [-127, 127] instead of [-128, 127] for 8 bits).
 * @return 0 for unsigned types, otherwise -2^(bits-1) (+1 when is_narrow).
 */
inline int QuantMin(int bits, bool is_unsigned = false, bool is_narrow = true) {
  if (is_unsigned) {
    return 0;
  }
  // Build -(2^(bits-1)) as -(2^(bits-1) - 1) - 1 using unsigned shifting, so bits == 32 yields
  // INT_MIN instead of negating the overflowed result of `1 << 31`.
  const auto shift = static_cast<unsigned int>(bits - 1);
  const int magnitude_minus_one = static_cast<int>((1U << shift) - 1U);
  return -magnitude_minus_one - 1 + (is_narrow ? 1 : 0);
}
65
66 int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
67 int quant_min, int quant_max, bool symmetric, bool narrow_range = false);
68
69 int CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, int num_bits,
70 bool symmetric, bool narrow_range = false);
71 int CalWeightQuantBiasPerLayer(const float *raw_datas, size_t elem_count, const std::vector<float> &dequant_datas,
72 std::vector<schema::QuantParamT> *quant_params);
73
74 void EncodeMinMax(float min_value, float max_value, int quant_min, int quant_max, bool symmetric, float *encode_min,
75 float *encode_max);
76 template <typename T>
QuantizeData(float origin_data,const schema::QuantParamT * quant_param,int quant_max,int quant_min)77 T QuantizeData(float origin_data, const schema::QuantParamT *quant_param, int quant_max, int quant_min) {
78 MS_CHECK_TRUE_MSG(quant_param != nullptr, 0, "quant_param is nullptr.");
79 const auto scale = quant_param->scale;
80 const int zero_point = quant_param->zeroPoint;
81 if (scale <= SCALE_THREASHOLD) {
82 return 0;
83 }
84 return [quant_max, quant_min, zero_point, scale, origin_data] {
85 auto quant_data = std::round(origin_data / scale + zero_point);
86 if (quant_data > quant_max) {
87 quant_data = quant_max;
88 } else if (quant_data < quant_min) {
89 quant_data = quant_min;
90 }
91 return static_cast<T>(quant_data);
92 }();
93 }
94
95 template <typename T>
QuantizeData(const float origin_data,const schema::QuantParamT * quant_param)96 T QuantizeData(const float origin_data, const schema::QuantParamT *quant_param) {
97 MS_ASSERT(quant_param != nullptr);
98 MS_ASSERT(quant_param->inited);
99 const auto num_bit = quant_param->numBits;
100 const auto narrow_range = quant_param->narrowRange;
101 const int32_t quant_max = QuantMax(num_bit, false);
102 const int32_t quant_min = QuantMin(num_bit, false, narrow_range);
103 return QuantizeData<T>(origin_data, quant_param, quant_max, quant_min);
104 }
105
106 template <typename T>
107 int DoPerLayerQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
108 const int &quant_max, const int &quant_min, const size_t &bit_num, std::vector<T> *quant_datas,
109 bool symmetric = false, bool narrow_range = false, bool cal_gain = false) {
110 auto min_max = GetFloatMinMaxValue(raw_datas, elem_count);
111 schema::QuantParamT quant_param;
112 int status = CalQuantizationParams(&quant_param, min_max.first, min_max.second, bit_num, quant_min, quant_max,
113 symmetric, narrow_range);
114 if (status != RET_OK) {
115 MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
116 return status;
117 }
118 quant_params->emplace_back(quant_param);
119 // update data and datatype
120 for (size_t i = 0; i < elem_count; i++) {
121 float raw_data = raw_datas[i];
122 auto quant_data = QuantizeData<T>(raw_data, &quant_param, quant_max, quant_min);
123 (*quant_datas)[i] = quant_data;
124 }
125 if (cal_gain) {
126 std::vector<float> dequant_datas(quant_datas->size());
127 for (size_t i = 0; i < elem_count; i++) {
128 dequant_datas.at(i) = quant_param.scale * (quant_datas->at(i) - quant_param.zeroPoint);
129 }
130 auto ret = CalWeightQuantBiasPerLayer(raw_datas, elem_count, dequant_datas, quant_params);
131 if (ret != RET_OK) {
132 MS_LOG(ERROR) << "Cal weight quant bias failed.";
133 return ret;
134 }
135 }
136 return RET_OK;
137 }
138
139 // Get the index of the bucket to which the current data belongs.
140 int GetBucketIndex(const std::vector<int> &dims, int preferred_dim, int data_index);
141
142 // Calculate the Compression effect of per-channel
143 int CalPerChannelGain(size_t bit_num, const std::vector<int> &dims, int preferred_dim);
144
145 // Get the min max of each channel
146 void GetAllChannelMinMax(const float *raw_datas, size_t elem_count, const std::vector<int> &dims, int preferred_dim,
147 std::map<int, MinMax> *per_channel_min_max);
148
149 // Calculate the distribution difference between quant and origin
150 int CalWeightQuantBias(const float *raw_datas, size_t elem_count, const std::vector<float> &dequant_datas,
151 std::vector<schema::QuantParamT> *quant_params, const std::vector<int> &dims, int preferred_dim);
152
153 template <typename T>
154 int DoPerChannelQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
155 int quant_max, int quant_min, size_t bit_num, std::vector<T> *quant_datas,
156 const std::vector<int> &dims, int preferred_dim, bool cal_gain = true, bool symmetric = false,
157 bool narrow_range = false) {
158 if (raw_datas == nullptr || quant_params == nullptr || quant_datas == nullptr) {
159 MS_LOG(ERROR) << "raw_data, quant_params or quant_data is nullptr.";
160 return RET_ERROR;
161 }
162 int ret;
163 auto count = std::accumulate(std::begin(dims), std::end(dims), 1, std::multiplies<>());
164 if (static_cast<size_t>(count) != elem_count) {
165 MS_LOG(ERROR) << " element != count";
166 return RET_ERROR;
167 }
168
169 CHECK_LESS_RETURN(dims.size(), static_cast<size_t>(preferred_dim + 1));
170 if (cal_gain) {
171 ret = CalPerChannelGain(bit_num, dims, preferred_dim);
172 if (ret == RET_NO_CHANGE) {
173 return RET_NO_CHANGE;
174 }
175 }
176
177 std::vector<float> dequant_datas(quant_datas->size());
178 // the key is bucket_index
179 std::map<int, MinMax> per_channel_min_max;
180 GetAllChannelMinMax(raw_datas, elem_count, dims, preferred_dim, &per_channel_min_max);
181
182 // Cal Quant param
183 for (auto min_max_map : per_channel_min_max) {
184 float min = min_max_map.second.min;
185 float max = min_max_map.second.max;
186 schema::QuantParamT quant_param;
187 ret = CalQuantizationParams(&quant_param, min, max, bit_num, quant_min, quant_max, symmetric, narrow_range);
188 if (ret != RET_OK) {
189 MS_LOG(ERROR) << "Cal quantization params failed.";
190 return ret;
191 }
192 quant_params->emplace_back(quant_param);
193 }
194 // Do quant
195 for (size_t i = 0; i < elem_count; i++) {
196 float raw_data = raw_datas[i];
197 auto bucket_index = GetBucketIndex(dims, preferred_dim, i);
198 MS_CHECK_GT(static_cast<int>(quant_params->size()), bucket_index, RET_ERROR);
199 auto quant_param = quant_params->at(bucket_index);
200 auto quant_data = QuantizeData<T>(raw_data, &quant_param, quant_max, quant_min);
201 (*quant_datas)[i] = quant_data;
202 // cal dequant(use for cal weight bias)
203 dequant_datas.at(i) = quant_param.scale * (quant_data - quant_param.zeroPoint);
204 }
205
206 if (cal_gain) {
207 ret = CalWeightQuantBias(raw_datas, elem_count, dequant_datas, quant_params, dims, preferred_dim);
208 if (ret != RET_OK) {
209 MS_LOG(ERROR) << "Cal weight quant bias failed.";
210 return ret;
211 }
212 }
213 return RET_OK;
214 }
215 } // namespace lite
216 } // namespace mindspore
217
218 #endif // MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
219