• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/quantizer/data_distribution.h"
18 #include <algorithm>
19 #include <vector>
20 #include <utility>
21 #include <set>
22 #include <cmath>
23 #include "tools/common/statistic_utils.h"
24 #include "src/common/utils.h"
25 
26 namespace mindspore::lite::quant {
RecordMaxMinValueArray(const std::vector<float> & data)27 int DataDistribution::RecordMaxMinValueArray(const std::vector<float> &data) {
28   if (data.empty()) {
29     return RET_ERROR;
30   }
31   auto min_max = GetFloatMinMaxValue(data.data(), data.size());
32   float min_num = min_max.first;
33   float max_num = min_max.second;
34   real_min_ = std::min(min_num, real_min_);
35   real_max_ = std::max(max_num, real_max_);
36   if (activation_quant_method_ == REMOVAL_OUTLIER) {
37     auto bak_data(data);
38     const float min_percentage = 0.0001;
39     const float max_percentage = 0.9999;
40     auto const quantile_min_index = static_cast<int>(min_percentage * bak_data.size());
41     auto const quantile_max_index = static_cast<int>(max_percentage * bak_data.size());
42     std::nth_element(bak_data.begin(), bak_data.begin() + quantile_min_index, bak_data.end());
43     auto quantile_min = bak_data.at(quantile_min_index);
44     std::nth_element(bak_data.begin() + quantile_min_index + 1, bak_data.begin() + quantile_max_index, bak_data.end());
45     auto quantile_max = bak_data.at(quantile_max_index);
46     MS_LOG(DEBUG) << "real_min_:" << real_min_ << " real_max_:" << real_max_ << " quantile_min:" << quantile_min
47                   << " quantile_max:" << quantile_max;
48     this->min_datas_.emplace_back(quantile_min);
49     this->max_datas_.emplace_back(quantile_max);
50   }
51   return RET_OK;
52 }
53 
UpdateInterval()54 void DataDistribution::UpdateInterval() {
55   auto max_value = std::max(fabs(this->real_max_), fabs(this->real_min_));
56   MS_CHECK_TRUE_RET_VOID(bin_num_ != 0);
57   this->interval_ = max_value / static_cast<float>(bin_num_);
58 }
59 
UpdateHistogram(const std::vector<float> & data)60 int DataDistribution::UpdateHistogram(const std::vector<float> &data) {
61   for (auto value : data) {
62     if (fabs(value) <= DBL_EPSILON) {
63       continue;
64     }
65     if (fabs(this->interval_) <= DBL_EPSILON) {
66       MS_LOG(ERROR) << "divisor 'interval' cannot be 0.";
67       return RET_ERROR;
68     }
69     int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval_), bin_num_ - 1);
70     this->histogram_[bin_index]++;
71   }
72   return RET_OK;
73 }
74 
DumpHistogram()75 void DataDistribution::DumpHistogram() {
76   MS_LOG(INFO) << "Print node " << cnode_->fullname_with_scope() << " histogram";
77   for (float item : this->histogram_) {
78     std::cout << item << " ";
79   }
80   std::cout << std::endl;
81 }
82 
HandleBinForKL(int quant_bint_nums,int bin_index,std::vector<float> * quantized_histogram,std::vector<float> * expanded_histogram)83 void DataDistribution::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
84                                       std::vector<float> *expanded_histogram) {
85   CHECK_NULL_RETURN_VOID(quantized_histogram);
86   CHECK_NULL_RETURN_VOID(expanded_histogram);
87   MS_CHECK_TRUE_RET_VOID(quant_bint_nums != 0);
88   const float bin_interval = static_cast<float>(bin_index) / static_cast<float>(quant_bint_nums);
89   MS_CHECK_TRUE_RET_VOID(quant_bint_nums <= static_cast<int>(quantized_histogram->size()));
90   // merge i bins to target bins
91   for (int i = 0; i < quant_bint_nums; ++i) {
92     const float start = i * bin_interval;
93     const float end = start + bin_interval;
94     const int left_upper = static_cast<int>(std::ceil(start));
95     if (left_upper > start) {
96       const double left_scale = left_upper - start;
97       MS_ASSERT((left_upper - 1) < this->histogram_.size());
98       quantized_histogram->at(i) += left_scale * this->histogram_[left_upper - 1];
99     }
100     const int right_lower = static_cast<int>(std::floor(end));
101     if (right_lower < end) {
102       const double right_scale = end - right_lower;
103       quantized_histogram->at(i) += right_scale * this->histogram_[right_lower];
104     }
105     std::for_each(this->histogram_.begin() + left_upper, this->histogram_.begin() + right_lower,
106                   [&quantized_histogram, i](float item) { quantized_histogram->at(i) += item; });
107   }
108   // expand target bins to i bins in order to calculate KL with reference_histogram
109   for (int i = 0; i < quant_bint_nums; ++i) {
110     const float start = i * bin_interval;
111     const float end = start + bin_interval;
112     float count = 0;
113     const int left_upper = static_cast<int>(std::ceil(start));
114     float left_scale = 0.0f;
115     if (left_upper > start) {
116       left_scale = left_upper - start;
117       if (fabs(this->histogram_[left_upper - 1]) > DBL_EPSILON) {
118         count += left_scale;
119       }
120     }
121     const int right_lower = static_cast<int>(std::floor(end));
122     double right_scale = 0.0f;
123     if (right_lower < end) {
124       right_scale = end - right_lower;
125       if (fabs(this->histogram_[right_lower]) > DBL_EPSILON) {
126         count += right_scale;
127       }
128     }
129     std::for_each(this->histogram_.begin() + left_upper, this->histogram_.begin() + right_lower, [&count](float item) {
130       bool is_zero = (item <= kEps && item >= -kEps);
131       if (!is_zero) {
132         count += 1;
133       }
134     });
135     if (fabs(count) <= DBL_EPSILON) {
136       continue;
137     }
138     const float average_num = quantized_histogram->at(i) / count;
139     if (left_upper > start && fabs(this->histogram_[left_upper - 1]) > DBL_EPSILON) {
140       expanded_histogram->at(left_upper - 1) += average_num * left_scale;
141     }
142     if (right_lower < end && fabs(this->histogram_[right_lower]) > DBL_EPSILON) {
143       expanded_histogram->at(right_lower) += average_num * right_scale;
144     }
145     for (int k = left_upper; k < right_lower; ++k) {
146       if (fabs(this->histogram_[k]) > DBL_EPSILON) {
147         expanded_histogram->at(k) += average_num;
148       }
149     }
150   }
151 }
152 
ComputeThreshold()153 int DataDistribution::ComputeThreshold() {
154   if (activation_quant_method_ != KL) {
155     return RET_OK;
156   }
157 
158   int threshold = INT8_MAX + 1;
159   float min_kl = FLT_MAX;
160   float after_threshold_sum = std::accumulate(this->histogram_.begin() + INT8_MAX + 1, this->histogram_.end(), 0.0f);
161 
162   for (int i = INT8_MAX + 1; i < this->bin_num_; ++i) {
163     std::vector<float> quantized_histogram(INT8_MAX + 1, 0);
164     std::vector<float> reference_histogram(this->histogram_.begin(), this->histogram_.begin() + i);
165     std::vector<float> expanded_histogram(i, 0);
166     reference_histogram[i - 1] += after_threshold_sum;
167     after_threshold_sum -= this->histogram_[i];
168     // handle bins for computing KL.
169     HandleBinForKL(INT8_MAX + 1, i, &quantized_histogram, &expanded_histogram);
170     const float kl = lite::KLDivergence(reference_histogram, expanded_histogram);
171     if (kl < min_kl) {
172       min_kl = kl;
173       threshold = i;
174     }
175   }
176   this->best_T_ = (static_cast<float>(threshold) + 0.5f) * this->interval_;
177   MS_LOG(DEBUG) << cnode_->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T_
178                 << " max: " << std::max(fabs(this->real_max_), fabs(this->real_min_));
179   return RET_OK;
180 }
181 
CalculateMinMaxScale()182 double DataDistribution::CalculateMinMaxScale() { return CalculateScale(this->real_min_, this->real_max_); }
183 
CalculateRemovalOutlierScale()184 double DataDistribution::CalculateRemovalOutlierScale() {
185   this->percent_result_ = CalQuantileMinMax(min_datas_, max_datas_);
186   return CalculateScale(percent_result_.first, percent_result_.second);
187 }
188 
CalQuantileMinMax(const std::vector<float> & min_datas,const std::vector<float> & max_datas)189 std::pair<float, float> DataDistribution::CalQuantileMinMax(const std::vector<float> &min_datas,
190                                                             const std::vector<float> &max_datas) {
191   MS_ASSERT(!min_datas.empty());
192   MS_ASSERT(!max_datas.empty());
193   auto avg_min = accumulate(min_datas.begin(), min_datas.end(), 0.0) / min_datas.size();
194   auto avg_max = accumulate(max_datas.begin(), max_datas.end(), 0.0) / max_datas.size();
195   return {avg_min, avg_max};
196 }
197 
CalculateScale(float min_value,float max_value)198 double DataDistribution::CalculateScale(float min_value, float max_value) {
199   EncodeMinMax(min_value, max_value, quant_min_, quant_max_, symmetric_, &encode_min_, &encode_max_);
200   auto q_range = quant_max_ - quant_min_;
201   MS_ASSERT(quant_max_ - quant_min_ > 0);
202   auto range = encode_max_ - encode_min_;
203   return range / q_range;
204 }
205 
CalculateKLScale()206 double DataDistribution::CalculateKLScale() {
207   return CalculateScale(-std::abs(this->best_T_), std::abs(this->best_T_));
208 }
209 
GetScale()210 double DataDistribution::GetScale() {
211   switch (this->activation_quant_method_) {
212     case MAX_MIN:
213       this->scale_ = CalculateMinMaxScale();
214       break;
215     case KL:
216       this->scale_ = CalculateKLScale();
217       break;
218     case REMOVAL_OUTLIER:
219       this->scale_ = CalculateRemovalOutlierScale();
220       break;
221     default:
222       MS_LOG(ERROR) << "Unsupported activation quant method " << this->activation_quant_method_;
223       return FLT_MAX;
224   }
225   return this->scale_;
226 }
227 
GetZeroPoint()228 int32_t DataDistribution::GetZeroPoint() {
229   if (symmetric_) {
230     zero_point_ = 0;
231   } else {
232     MS_ASSERT(scale_ > 0);
233     zero_point_ = static_cast<int32_t>(std::round(quant_min_ - encode_min_ / scale_));
234   }
235   return zero_point_;
236 }
237 }  // namespace mindspore::lite::quant
238