• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
18 #define MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
19 
20 #include <float.h>
21 #include <cmath>
22 #include <climits>
23 #include <limits>
24 #include <algorithm>
25 #include <vector>
26 #include "include/errorcode.h"
27 #include "src/common/log_adapter.h"
28 #include "ir/dtype/type_id.h"
29 
30 namespace mindspore {
31 
32 namespace schema {
33 struct QuantParamT;
34 }
35 
36 namespace lite {
// Status returned by DoPerChannelQuant when a QUANT_WEIGHT tensor has too few elements
// per filter for quantization to be worthwhile; callers presumably keep the tensor as float.
constexpr int RET_QUANT_CONTINUE = 2;
// Scales less than or equal to this value are treated as degenerate by the
// (quant_max, quant_min) overload of QuantizeData, which then returns 0.
constexpr double SCALE_THREASHOLD = 1e-38;

// NOTE(review): not referenced in this header; presumably the number of quant params
// used for per-tensor (per-layer) quantization — confirm against callers.
constexpr int kPerTensor = 1;
41 
QuantMax(int bits,TypeId type)42 inline int QuantMax(int bits, TypeId type) {
43   if (type == kNumberTypeInt8) {
44     return (1 << (bits - 1)) - 1;
45   } else if (type == kNumberTypeUInt8) {
46     return (1 << bits) - 1;
47   }
48   return 0;
49 }
50 
QuantMin(int bits,TypeId type)51 inline int QuantMin(int bits, TypeId type) {
52   if (type == kNumberTypeInt8) {
53     return -(1 << (bits - 1));
54   }
55   return 0;
56 }
57 
58 STATUS GetMaxMinPerChannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas,
59                            bool channel_at_first, float *desired_max, float *desired_min);
60 
61 STATUS CalQuantizationParams(schema::QuantParamT *quant_param, double real_min, double real_max, bool narrow_range,
62                              int quant_max, int quant_min, int num_bits);
63 
64 template <typename T>
QuantizeData(const float originData,const schema::QuantParamT * quantParam)65 T QuantizeData(const float originData, const schema::QuantParamT *quantParam) {
66   MS_ASSERT(quantParam != nullptr);
67   MS_ASSERT(quantParam->inited);
68   const auto scale = quantParam->scale;
69   const auto zeroPoint = quantParam->zeroPoint;
70   const auto numBit = quantParam->numBits;
71   const auto narrowRange = quantParam->narrowRange;
72   const int32_t quantMax = (1 << (unsigned int)(numBit - 1)) - 1;
73   const int32_t quantMin = -1 * (1 << (unsigned int)(numBit - 1)) + (narrowRange ? 1 : 0);
74   const double maxLimit = static_cast<float>(quantMax - zeroPoint) * scale;
75   const double minLimit = static_cast<float>(quantMin - zeroPoint) * scale;
76 
77   return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] {
78     double tmp;
79     if (originData > maxLimit) {
80       tmp = maxLimit;
81     } else if (originData < minLimit) {
82       tmp = minLimit;
83     } else {
84       tmp = originData;
85     }
86     auto quantData = static_cast<T>(std::round(zeroPoint + tmp / scale));
87     return quantData;
88   }();
89 }
90 
91 template <typename T>
QuantizeData(float originData,const schema::QuantParamT * quantParam,int quant_max,int quant_min)92 T QuantizeData(float originData, const schema::QuantParamT *quantParam, int quant_max, int quant_min) {
93   MS_ASSERT(quantParam != nullptr);
94   MS_ASSERT(quantParam->inited);
95   const auto scale = quantParam->scale;
96   const int zeroPoint = quantParam->zeroPoint;
97   const int maxLimit = quant_max;
98   const int minLimit = quant_min;
99 
100   if (scale <= SCALE_THREASHOLD) {
101     return 0;
102   }
103 
104   return [maxLimit, minLimit, zeroPoint, scale, originData] {
105     auto quant_data = std::round(originData / scale + zeroPoint);
106     if (quant_data > maxLimit) {
107       quant_data = maxLimit;
108     } else if (quant_data < minLimit) {
109       quant_data = minLimit;
110     }
111     return static_cast<T>(quant_data);
112   }();
113 }
114 
115 template <typename T>
DoPerLayerQuant(const float * raw_datas,size_t elem_count,std::vector<schema::QuantParamT> * quant_params,const int & quant_max,const int & quant_min,const size_t & bit_num,const bool & k_means,std::vector<T> * quant_datas)116 STATUS DoPerLayerQuant(const float *raw_datas, size_t elem_count, std::vector<schema::QuantParamT> *quant_params,
117                        const int &quant_max, const int &quant_min, const size_t &bit_num, const bool &k_means,
118                        std::vector<T> *quant_datas) {
119   float min = FLT_MAX;
120   float max = -FLT_MIN;
121   for (uint32_t i = 0; i < elem_count; i++) {
122     min = std::min(min, raw_datas[i]);
123     max = std::max(max, raw_datas[i]);
124   }
125 
126   schema::QuantParamT quant_param;
127   if (!k_means) {
128     STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
129     if (status != RET_OK) {
130       MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
131       return status;
132     }
133   }
134   quant_params->emplace_back(quant_param);
135   // update data and datatype
136   for (uint32_t i = 0; i < elem_count; i++) {
137     float raw_data = raw_datas[i];
138     if (!k_means) {
139       auto quant_data = QuantizeData<T>(raw_data, &quant_param, quant_max, quant_min);
140       (*quant_datas)[i] = quant_data;
141     }
142   }
143   return RET_OK;
144 }
145 
146 template <typename T>
147 STATUS DoPerChannelQuant(const float *raw_datas, size_t elem_count, const schema::QuantType &quant_type,
148                          std::vector<schema::QuantParamT> *quant_params, const int &quant_max, const int &quant_min,
149                          const size_t &bit_num, const bool &k_means, std::vector<T> *quant_datas, int channels,
150                          bool channel_at_first = true) {
151   static const int quant_param_size = 32 * 8;
152   std::vector<float> dequant_datas(quant_datas->size());
153   if (channels <= 0) {
154     MS_LOG(ERROR) << "channels must be greater than 0";
155     return RET_ERROR;
156   }
157   size_t one_filter_size = elem_count / channels;
158   bool do_quant = quant_param_size / (sizeof(float) * 8 - bit_num) < one_filter_size;
159   if (!do_quant && quant_type == schema::QuantType_QUANT_WEIGHT) {
160     MS_LOG(INFO) << "too few elements in a filter, no need to quantize. " << one_filter_size;
161     return RET_QUANT_CONTINUE;
162   }
163   for (int i = 0; i < channels; i++) {
164     float min = FLT_MAX;
165     float max = -FLT_MAX;
166     STATUS status =
167       GetMaxMinPerChannel(channels, one_filter_size, i, elem_count, raw_datas, channel_at_first, &max, &min);
168     if (status != RET_OK) {
169       MS_LOG(ERROR) << "GetMaxMinPerChannel failed" << status;
170       return status;
171     }
172     schema::QuantParamT quant_param;
173     status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bit_num);
174     if (status != RET_OK) {
175       MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
176       return status;
177     }
178     // do quantization
179     double average_dequant = 0;
180     double average_raw = 0;
181     for (uint32_t j = 0; j < one_filter_size; j++) {
182       auto index = j + i * one_filter_size;
183       if (!channel_at_first) {
184         index = j * channels + i;
185       }
186       MS_ASSERT(index < elem_count);
187       float raw_data = raw_datas[index];
188       auto quant_data = QuantizeData<T>(raw_data, &quant_param, quant_max, quant_min);
189       (*quant_datas)[index] = quant_data;
190 
191       if (quant_type == schema::QuantType_QUANT_WEIGHT) {
192         float dequant_data = quant_param.scale * (quant_data - quant_param.zeroPoint);
193         dequant_datas[index] = dequant_data;
194         average_dequant += dequant_data;
195         average_raw += raw_data;
196       }
197     }
198     if (quant_type == schema::QuantType_QUANT_WEIGHT && !k_means) {
199       // mean
200       average_dequant = average_dequant / one_filter_size;
201       average_raw = average_raw / one_filter_size;
202       // std
203       double variance_dequant = 0;
204       double variance_raw = 0;
205       for (uint32_t j = 0; j < one_filter_size; j++) {
206         auto index = j + i * one_filter_size;
207         if (!channel_at_first) {
208           index = j * channels + i;
209         }
210         MS_ASSERT(index < elem_count);
211         variance_dequant += std::pow(dequant_datas[index] - average_dequant, 2);
212         variance_raw += std::pow(raw_datas[index] - average_raw, 2);
213       }
214       variance_dequant = std::sqrt(variance_dequant / one_filter_size);
215       variance_raw = std::sqrt(variance_raw / one_filter_size);
216       quant_param.varCorr = 1;
217       if (variance_raw != 0 && variance_dequant != 0) {
218         auto temp_var_corr = variance_raw / variance_dequant;
219         if (temp_var_corr > 0 && temp_var_corr < 10) {
220           quant_param.varCorr = temp_var_corr;
221         } else {
222           MS_LOG(WARNING) << "unexpected var_corr: " << temp_var_corr;
223         }
224       }
225       quant_param.meanCorr = average_raw - average_dequant * quant_param.varCorr;
226     }
227     quant_params->emplace_back(quant_param);
228   }
229   return RET_OK;
230 }
231 }  // namespace lite
232 }  // namespace mindspore
233 
234 #endif  // MINDSPORE_LITE_SRC_COMMON_QUANT_UTILS_H_
235