• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "tools/converter/quantizer/full_quant_quantizer.h"
18 #include <dirent.h>
19 #include <future>
20 #include <map>
21 #include <set>
22 #include <memory>
23 #include <algorithm>
24 #include <unordered_map>
25 #include <functional>
26 #include <numeric>
27 #include <utility>
28 #include <string>
29 #include <thread>
30 #include <vector>
31 #include <fstream>
32 #include "ops/fusion/full_connection.h"
33 #include "tools/converter/ops/ops_def.h"
34 #include "src/tensor.h"
35 #include "tools/converter/quantizer/quant_cast.h"
36 #include "tools/converter/quantizer/quantize_util.h"
37 #include "tools/optimizer/common/gllo_utils.h"
38 #include "tools/optimizer/common/format_utils.h"
39 #include "src/common/log_adapter.h"
40 #include "securec/include/securec.h"
41 #include "tools/common/tensor_util.h"
42 #include "src/common/quant_utils.h"
43 #include "src/common/utils.h"
44 #include "tools/converter/preprocess/image_preprocess.h"
45 #include "nnacl/op_base.h"
46 
47 using std::string;
48 using std::vector;
49 
50 namespace mindspore::lite::quant {
51 namespace {
52 static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
53                                                          prim::kPrimMatMul, prim::kPrimFullConnection,
54                                                          prim::kPrimLayerNormFusion};
55 constexpr int kMinSize = 0;
56 constexpr int kMaxSize = 65535;
57 }  // namespace
58 namespace {
ComputeBiasDataAndQuantParam(const std::vector<double> & bias_scales,const std::vector<double> & input_scales,const float * raw_datas,const QuantParamHolderPtr & quant_param_holder,std::vector<schema::QuantParamT> * quant_params,std::vector<int32_t> * quant_datas)59 STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
60                                     const float *raw_datas, const QuantParamHolderPtr &quant_param_holder,
61                                     std::vector<schema::QuantParamT> *quant_params, std::vector<int32_t> *quant_datas) {
62   MS_ASSERT(raw_datas != nullptr && quant_param_holder != nullptr);
63   MS_ASSERT(quant_params != nullptr && quant_datas != nullptr);
64   double bias_scale_tmp;
65   const constexpr double quanted_bias_abs_limit = 0.5 * INT32_MAX;
66   MS_CHECK_TRUE_MSG(quant_param_holder->get_input_quant_params().size() > 1, RET_ERROR, "invalid access.");
67   auto weight_quant_params = quant_param_holder->get_input_quant_params().at(1);
68   auto shape_size = quant_datas->size();
69   if (bias_scales.size() == shape_size) {
70     for (size_t i = 0; i < shape_size; i++) {
71       bias_scale_tmp = bias_scales[i];
72       if (fabs(bias_scale_tmp) <= 0.0f) {
73         MS_LOG(ERROR) << "divisor 'bias_scale_tmp' cannot be 0.";
74         return RET_ERROR;
75       }
76       if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) {
77         MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[i].scale
78                       << " is too small, need to update";
79         // update filter scale and zp
80         double activate_scale = input_scales[0];
81         double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit);
82         weight_quant_params[i].scale = filter_scale;
83         weight_quant_params[i].zeroPoint = 0;
84         quant_param_holder->set_input_quant_param(1, weight_quant_params);
85         bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit;
86         quant_params->at(i).scale = bias_scale_tmp;
87         MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
88       }
89       auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
90       quant_datas->at(i) = quant_data;
91     }
92     return RET_OK;
93   } else if (bias_scales.size() == 1) {
94     // for fc, per tensor quant
95     bias_scale_tmp = quant_params->front().scale;
96     float max_raw_data = 0.0f;
97     for (size_t i = 0; i < shape_size; i++) {
98       if (std::abs(raw_datas[i]) > max_raw_data) {
99         max_raw_data = std::abs(raw_datas[i]);
100       }
101     }
102     if (fabs(bias_scale_tmp) <= 0.0f) {
103       MS_LOG(ERROR) << "divisor 'bias_scale_tmp' cannot be 0.";
104       return RET_ERROR;
105     }
106     if (std::abs(max_raw_data / bias_scale_tmp) >= quanted_bias_abs_limit) {
107       MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[0].scale
108                     << " is too small, need to update";
109       double activate_scale = input_scales[0];
110       MS_CHECK_TRUE_MSG(activate_scale != 0, RET_ERROR, "activate_scale == 0");
111       double filter_scale = std::abs(max_raw_data) / (activate_scale * quanted_bias_abs_limit);
112       weight_quant_params[0].scale = filter_scale;
113       weight_quant_params[0].zeroPoint = 0;
114       quant_param_holder->set_input_quant_param(1, weight_quant_params);
115       bias_scale_tmp = max_raw_data / quanted_bias_abs_limit;
116       quant_params->front().scale = bias_scale_tmp;
117       MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
118     }
119     for (size_t i = 0; i < shape_size; i++) {
120       auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
121       quant_datas->at(i) = quant_data;
122     }
123     return RET_OK;
124   }
125   MS_LOG(ERROR) << "unexpected input_scales size: " << input_scales.size()
126                 << " weight_scales size: " << weight_quant_params.size();
127   return RET_ERROR;
128 }
129 }  // namespace
130 
RecordMaxMinValue(const std::vector<float> & data)131 STATUS DivergInfo::RecordMaxMinValue(const std::vector<float> &data) {
132   for (float val : data) {
133     max = std::max(val, max);
134     min = std::min(val, min);
135   }
136   return RET_OK;
137 }
138 
RecordMaxMinValueArray(const std::vector<float> & data)139 STATUS DivergInfo::RecordMaxMinValueArray(const std::vector<float> &data) {
140   if (data.empty()) {
141     return RET_ERROR;
142   }
143   float max_num = data.at(0);
144   float min_num = data.at(0);
145   for (float val : data) {
146     max_num = std::max(val, max_num);
147     min_num = std::min(val, min_num);
148   }
149   this->max_datas.emplace_back(max_num);
150   this->min_datas.emplace_back(min_num);
151   return RET_OK;
152 }
153 
UpdateInterval()154 void DivergInfo::UpdateInterval() {
155   auto max_value = std::max(fabs(this->max), fabs(this->min));
156   MS_ASSERT(bin_num != 0);
157   this->interval = max_value / static_cast<float>(bin_num);
158 }
159 
UpdateHistogram(const std::vector<float> & data)160 STATUS DivergInfo::UpdateHistogram(const std::vector<float> &data) {
161   for (auto value : data) {
162     if (value == 0) {
163       continue;
164     }
165     if (this->interval == 0) {
166       MS_LOG(ERROR) << "divisor 'interval' cannot be 0.";
167       return RET_ERROR;
168     }
169     int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
170     this->histogram[bin_index]++;
171   }
172   return RET_OK;
173 }
174 
DumpHistogram()175 void DivergInfo::DumpHistogram() {
176   MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram";
177   for (float item : this->histogram) {
178     std::cout << item << " ";
179   }
180   std::cout << std::endl;
181 }
182 
// Merges the first `bin_index` histogram bins down into `quant_bint_nums`
// quantized bins, then expands those quantized bins back out to `bin_index`
// bins so the result can be compared against the reference histogram with KL
// divergence (see ComputeThreshold).
// Because `bin_index / quant_bint_nums` is generally fractional, the bin
// edges fall between integer indices; the partially covered boundary bins are
// split proportionally (left_scale / right_scale).
// quantized_histogram: out, size quant_bint_nums, zero-initialized by caller.
// expanded_histogram:  out, size bin_index, zero-initialized by caller.
void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
                                std::vector<float> *expanded_histogram) {
  MS_ASSERT(quantized_histogram != nullptr && expanded_histogram != nullptr);
  MS_ASSERT(quant_bint_nums != 0);
  // Width (in source bins) of each quantized bin; usually non-integral.
  const float bin_interval = static_cast<float>(bin_index) / static_cast<float>(quant_bint_nums);
  // merge i bins to target bins
  for (int i = 0; i < quant_bint_nums; ++i) {
    const float start = i * bin_interval;
    const float end = start + bin_interval;
    const int left_upper = static_cast<int>(std::ceil(start));
    if (left_upper > start) {
      // Fractional part of the leftmost source bin covered by this target bin.
      const double left_scale = left_upper - start;
      quantized_histogram->at(i) += left_scale * this->histogram[left_upper - 1];
    }
    const int right_lower = static_cast<int>(std::floor(end));
    if (right_lower < end) {
      // Fractional part of the rightmost source bin covered by this target bin.
      const double right_scale = end - right_lower;
      quantized_histogram->at(i) += right_scale * this->histogram[right_lower];
    }
    // Fully covered interior source bins contribute whole counts.
    std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower,
                  [&quantized_histogram, i](float item) { quantized_histogram->at(i) += item; });
  }
  // expand target bins to i bins in order to calculate KL with reference_histogram
  for (int i = 0; i < quant_bint_nums; ++i) {
    const float start = i * bin_interval;
    const float end = start + bin_interval;
    // `count` is the (fractional) number of non-empty source bins under this
    // target bin; the merged mass is redistributed evenly over them.
    float count = 0;
    const int left_upper = static_cast<int>(std::ceil(start));
    float left_scale = 0.0f;
    if (left_upper > start) {
      left_scale = left_upper - start;
      if (this->histogram[left_upper - 1] != 0) {
        count += left_scale;
      }
    }
    const int right_lower = static_cast<int>(std::floor(end));
    double right_scale = 0.0f;
    if (right_lower < end) {
      right_scale = end - right_lower;
      if (this->histogram[right_lower] != 0) {
        count += right_scale;
      }
    }
    std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, [&count](float item) {
      if (item != 0) {
        count += 1;
      }
    });
    if (count == 0) {
      // All source bins under this target bin were empty; nothing to expand.
      continue;
    }
    const float average_num = quantized_histogram->at(i) / count;
    // Only originally non-empty bins receive redistributed mass, preserving
    // the zero pattern of the source histogram.
    if (left_upper > start && this->histogram[left_upper - 1] != 0) {
      expanded_histogram->at(left_upper - 1) += average_num * left_scale;
    }
    if (right_lower < end && this->histogram[right_lower] != 0) {
      expanded_histogram->at(right_lower) += average_num * right_scale;
    }
    for (int k = left_upper; k < right_lower; ++k) {
      if (this->histogram[k] != 0) {
        expanded_histogram->at(k) += average_num;
      }
    }
  }
}
248 
// Chooses the calibration threshold best_T for this tensor according to the
// configured activation quant method:
//   MAX_MIN          — simply the absolute peak value.
//   REMOVAL_OUTLIER  — percentile-trimmed min/max via OutlierMethod.
//   otherwise (KL)   — sweeps candidate thresholds over the histogram bins
//                      and keeps the one minimizing KL divergence between the
//                      clipped reference distribution and its (INT8_MAX + 1)-
//                      bin quantized/re-expanded approximation.
// Always returns RET_OK.
STATUS DivergInfo::ComputeThreshold() {
  if (activation_quant_method == MAX_MIN) {
    this->best_T = std::max(fabs(this->max), fabs(this->min));
    MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
    return RET_OK;
  }

  if (activation_quant_method == REMOVAL_OUTLIER && !this->min_datas.empty()) {
    this->percent_result = OutlierMethod(min_datas, max_datas);
    this->best_T = std::max(std::fabs(percent_result.first), std::fabs(percent_result.second));
    return RET_OK;
  }

  // KL method: candidate thresholds start at 128 bins (INT8_MAX + 1) —
  // anything narrower could not be represented by an int8 range anyway.
  int threshold = INT8_MAX + 1;
  float min_kl = FLT_MAX;
  // Mass above the current candidate threshold; folded into the last
  // reference bin (saturation), and shrunk as the candidate widens.
  float after_threshold_sum = std::accumulate(this->histogram.begin() + INT8_MAX + 1, this->histogram.end(), 0.0f);

  for (int i = INT8_MAX + 1; i < this->bin_num; ++i) {
    std::vector<float> quantized_histogram(INT8_MAX + 1, 0);
    std::vector<float> reference_histogram(this->histogram.begin(), this->histogram.begin() + i);
    std::vector<float> expanded_histogram(i, 0);
    reference_histogram[i - 1] += after_threshold_sum;
    after_threshold_sum -= this->histogram[i];
    // handle bins for computing KL.
    HandleBinForKL(INT8_MAX + 1, i, &quantized_histogram, &expanded_histogram);
    // KL(p || q) with both inputs normalized in place; q[i] == 0 while
    // p[i] != 0 is penalized with a fixed +1 instead of infinity.
    auto KLDivergence = [](std::vector<float> p, std::vector<float> q) {
      auto sum = 0.0f;
      std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; });
      std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; });
      sum = 0.0f;
      std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; });
      std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; });

      float result = 0.0f;
      const int size = p.size();
      for (int i = 0; i < size; ++i) {
        if (p[i] != 0) {
          if (q[i] == 0) {
            result += 1.0f;
          } else {
            result += (p[i] * std::log((p[i]) / (q[i])));
          }
        }
      }
      return result;
    };
    const float kl = KLDivergence(reference_histogram, expanded_histogram);
    if (kl < min_kl) {
      min_kl = kl;
      threshold = i;
    }
  }
  // Threshold value is the center of the winning bin.
  this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
  MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T
                << " max: " << std::max(fabs(this->max), fabs(this->min));
  return RET_OK;
}
306 
GetScale()307 std::pair<CNodePtr, float> DivergInfo::GetScale() {
308   float max_value = this->best_T;
309   float min_value = -max_value;
310 
311   if (this->activation_quant_method == REMOVAL_OUTLIER) {
312     min_value = percent_result.first;
313     max_value = percent_result.second;
314   }
315 
316   MS_CHECK_TRUE_MSG(quant_max - quant_min != 0, {}, "quant_max - quant_min == 0");
317   float scale = (max_value - min_value) / (quant_max - quant_min);
318   this->scale_tmp = scale;
319   MS_ASSERT(fabs(scale) <= 0.0f);
320   return std::make_pair(this->cnode, scale);
321 }
322 
GetZeropoint()323 std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
324   int zero_point = 0;
325   if (quant_min == 0 && quant_max == UINT8_MAX) {
326     zero_point = INT8_MAX + 1;
327   } else if (quant_min == INT_LEAST8_MIN + 1 && quant_max == INT8_MAX) {
328     zero_point = 0;
329   } else {
330     MS_LOG(WARNING) << "unexpected quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
331   }
332   if (this->activation_quant_method == REMOVAL_OUTLIER) {
333     MS_CHECK_TRUE_MSG(fabs(scale_tmp) <= 0.0f, {}, "fabs(scale_tmp) > 0.0f");
334     zero_point = std::round(quant_max - percent_result.second / scale_tmp);
335   }
336   return std::make_pair(this->cnode, zero_point);
337 }
338 
GetInputDivergInfo()339 std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetInputDivergInfo() {
340   return &this->inputs_diverg_info_;
341 }
342 
GetOutputDivergInfo()343 std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetOutputDivergInfo() {
344   return &this->outputs_diverg_info_;
345 }
346 
RecordMaxMinValue(const vector<float> & data,const std::unique_ptr<DivergInfo> & diverg_info)347 STATUS Calibrator::RecordMaxMinValue(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
348   auto ret = diverg_info->RecordMaxMinValue(data);
349   if (ret != RET_OK) {
350     MS_LOG(ERROR) << "Record max min value failed.";
351     return ret;
352   }
353   ret = diverg_info->RecordMaxMinValueArray(data);
354   if (ret != RET_OK) {
355     MS_LOG(ERROR) << "Record max min value array failed.";
356     return ret;
357   }
358   return RET_OK;
359 }
360 
// Computes calibration thresholds for every recorded tensor.
// Output tensors are always computed directly. For input tensors, if the
// producing node already has a computed output DivergInfo (node A's input is
// node B's output), that result is copied instead of recomputed — except
// through TupleGetItem, whose output info does not describe the selected
// element. Returns the first non-RET_OK status encountered.
STATUS Calibrator::ComputeThreshold() {
  for (auto &kv : this->outputs_diverg_info_) {
    auto &outputs_diverg_info = kv.second;
    for (auto &diverg_info : outputs_diverg_info) {
      auto ret = diverg_info->ComputeThreshold();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Compute threshold failed.";
        return ret;
      }
    }
  }
  // node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as
  for (auto &kv : this->inputs_diverg_info_) {
    auto &input_infos = kv.second;
    for (size_t i = 0; i < input_infos.size(); i++) {
      auto cnode = input_infos[i]->cnode;
      bool already_computed = false;
      // NOTE(review): this assumes input_infos[i] describes cnode's
      // (i + 1)-th input (input 0 is the primitive), i.e. that one DivergInfo
      // was recorded per activation input in order — verify against the
      // collection site.
      auto input = cnode->input(i + 1);
      if (input->isa<mindspore::CNode>()) {
        auto input_cnode = input->cast<CNodePtr>();
        // Linear scan of all recorded outputs for the producer node.
        for (const auto &outputs_diverg_info : outputs_diverg_info_) {
          if (already_computed) {
            break;
          }
          for (const auto &output_diverg_info : outputs_diverg_info.second) {
            auto output_diverg_cnode = output_diverg_info->cnode;
            if (output_diverg_cnode == input_cnode) {
              if (NodePrimitiveType(input_cnode) != lite::kNameTupleGetItem) {
                // Reuse the producer's stats, but keep this consumer's cnode.
                *(input_infos[i]) = *output_diverg_info;
                input_infos[i]->cnode = cnode;
                already_computed = true;
                break;
              }
            }
          }
        }
      }
      if (!already_computed) {
        auto ret = input_infos[i]->ComputeThreshold();
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "ComputeThreshold failed.";
          return ret;
        }
      }
    }
  }
  return RET_OK;
}
409 
UpdateDivergInterval(std::unordered_map<std::string,std::vector<std::unique_ptr<DivergInfo>>> * diverg_info)410 STATUS Calibrator::UpdateDivergInterval(
411   std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) {
412   MS_ASSERT(diverg_info != nullptr);
413   for (auto &kv : *diverg_info) {
414     for (auto &info : kv.second) {
415       info->UpdateInterval();
416     }
417   }
418   return RET_OK;
419 }
420 
UpdateDataFrequency(const vector<float> & data,const std::unique_ptr<DivergInfo> & diverg_info)421 STATUS Calibrator::UpdateDataFrequency(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
422   MS_ASSERT(diverg_info != nullptr);
423   return diverg_info->UpdateHistogram(data);
424 }
425 
AddQuantizedOp(const CNodePtr & cnode)426 STATUS Calibrator::AddQuantizedOp(const CNodePtr &cnode) {
427   if (cnode == nullptr) {
428     MS_LOG(ERROR) << "To be quantized cnode is null";
429     return RET_ERROR;
430   }
431   string node_name = cnode->fullname_with_scope();
432   std::unique_ptr<DivergInfo> input_diverg = std::make_unique<DivergInfo>(
433     cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
434   MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
435   std::unique_ptr<DivergInfo> output_diverg = std::make_unique<DivergInfo>(
436     cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
437   MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
438   inputs_diverg_info_[node_name].push_back(std::move(input_diverg));
439   outputs_diverg_info_[node_name].push_back(std::move(output_diverg));
440   return RET_OK;
441 }
442 
GenerateInputData(const std::string & input_name,size_t image_index,mindspore::tensor::MSTensor * tensor) const443 STATUS Calibrator::GenerateInputData(const std::string &input_name, size_t image_index,
444                                      mindspore::tensor::MSTensor *tensor) const {
445   return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
446 }
447 
FullQuantQuantizer(FuncGraphPtr graph,int bit_num,TypeId target_type,bool per_channel)448 FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type, bool per_channel)
449     : Quantizer(std::move(graph)) {
450   MS_ASSERT(graph != nullptr);
451   this->per_channel_ = per_channel;
452   this->bit_num = bit_num;
453   this->target_type_ = target_type;
454   if (target_type == kNumberTypeInt8) {
455     quant_max = (1 << (this->bit_num - 1)) - 1;  // 127
456     quant_min = -quant_max;                      // -127
457   } else if (target_type == kNumberTypeUInt8) {
458     quant_max = (1 << this->bit_num) - 1;  // 255
459     quant_min = 0;
460   } else {
461     MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
462   }
463   calibrator_ = std::make_unique<Calibrator>(this->bit_num, quant_max, quant_min);
464   if (calibrator_ == nullptr) {
465     MS_LOG(ERROR) << "create calibrator failed!";
466     return;
467   }
468 }
469 
// Releases the owned inference sessions and models. Each session is deleted
// before its model — presumably the session references model memory, so keep
// this order (NOTE(review): confirm against the session API's ownership
// contract).
FullQuantQuantizer::~FullQuantQuantizer() {
  delete fp32_session_;
  delete fp32_model_;
  delete int8_session_;
  delete int8_model_;
}
476 
SetInOutQuantParam(const AnfNodePtr & input_node,const std::unique_ptr<DivergInfo> & info,const PrimitivePtr & primitive,bool is_input,size_t index) const477 STATUS FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DivergInfo> &info,
478                                               const PrimitivePtr &primitive, bool is_input, size_t index) const {
479   auto quant_param_holder = GetCNodeQuantHolder(primitive);
480   MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
481   schema::QuantParamT quant_param;
482   TypeId type_id = kTypeUnknown;
483   if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
484     MS_LOG(ERROR) << "Get data type failed.";
485     return RET_ERROR;
486   }
487   if (type_id == kNumberTypeFloat32 && info != nullptr) {
488     auto scale = info->GetScale().second;
489     if (scale == 0) {
490       MS_LOG(WARNING) << "The input or output values are very close to 0, so set the scale to 1.";
491       quant_param.scale = 1;
492     } else {
493       quant_param.scale = scale;
494     }
495     quant_param.zeroPoint = info->GetZeropoint().second;
496     quant_param.max = info->max;
497     quant_param.min = info->min;
498     quant_param.numBits = bit_num;
499     quant_param.narrowRange = false;
500     quant_param.inited = true;
501     quant_param.roundType = 1;
502     quant_param.multiplier = 1;
503   } else {
504     quant_param.inited = false;
505   }
506   std::vector<schema::QuantParamT> quant_params = {quant_param};
507   if (is_input) {
508     quant_param_holder->set_input_quant_param(index, quant_params);
509   } else {
510     quant_param_holder->set_output_quant_param(index, quant_params);
511   }
512   return RET_OK;
513 }
514 
DoWeightQuant(const std::string & op_name,const AnfNodePtr & weight,const PrimitivePtr & primitive,bool per_channel,int input_index) const515 STATUS FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight,
516                                          const PrimitivePtr &primitive, bool per_channel, int input_index) const {
517   MS_ASSERT(weight != nullptr);
518   MS_ASSERT(primitive != nullptr);
519   // perlayer
520   if (!weight->isa<Parameter>()) {
521     MS_LOG(ERROR) << "not a parameter";
522     return RET_PARAM_INVALID;
523   }
524   auto parameter = weight->cast<ParameterPtr>();
525   if (parameter == nullptr) {
526     MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter";
527     return RET_NULL_PTR;
528   }
529   auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
530   if (tensor_info == nullptr) {
531     MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
532     return RET_NULL_PTR;
533   }
534   auto bit_num_t = bit_num;
535   auto quant_max_t = quant_max;
536   auto quant_min_t = quant_min;
537   auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
538   auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_QUANT_ALL, quant_max_t, quant_min_t,
539                                             bit_num_t, weight_quant_type, kNumberTypeInt8, input_index - 1);
540   if (status != RET_OK) {
541     MS_LOG(ERROR) << "QuantFilter failed: " << status;
542     return status;
543   }
544   // set dtype
545   auto abstractBase = parameter->abstract();
546   if (abstractBase == nullptr) {
547     MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name();
548     return RET_NULL_PTR;
549   }
550   if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
551     MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << parameter->name();
552     return RET_ERROR;
553   }
554   auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
555   if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
556     MS_LOG(ERROR) << "abstractTensor is nullptr, " << parameter->name();
557     return RET_NULL_PTR;
558   }
559   abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt8));
560   return RET_OK;
561 }
562 
DoBiasQuant(const AnfNodePtr & bias,const PrimitivePtr & primitive)563 STATUS FullQuantQuantizer::DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive) {
564   if (primitive == nullptr || bias == nullptr) {
565     MS_LOG(ERROR) << "null pointer!";
566     return RET_NULL_PTR;
567   }
568   auto bias_parameter_ptr = bias->cast<ParameterPtr>();
569   MS_ASSERT(bias_parameter_ptr != nullptr);
570   auto bias_default_param = bias_parameter_ptr->default_param();
571   auto bias_param = bias_default_param->cast<tensor::TensorPtr>();
572   MS_ASSERT(bias_parameter != nullptr);
573   auto quant_param_holder = GetCNodeQuantHolder(primitive);
574   MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
575   auto active_weight_quant_params = quant_param_holder->get_input_quant_params();
576 
577   auto active_params = active_weight_quant_params.at(FIRST_INPUT);
578   auto weight_params = active_weight_quant_params.at(SECOND_INPUT);
579 
580   vector<double> input_scales;
581   vector<double> filter_scales;
582   vector<double> bias_scales;
583   size_t sizeX = active_params.size();
584   for (size_t i = 0; i < sizeX; i++) {
585     input_scales.emplace_back(active_params[i].scale);
586   }
587   size_t sizeY = weight_params.size();
588   if (sizeX != sizeY) {
589     if (sizeX > 1 && sizeY > 1) {
590       MS_LOG(ERROR) << "input and filter's scale count cannot match!";
591       return RET_ERROR;
592     }
593   }
594   for (size_t i = 0; i < sizeY; i++) {
595     filter_scales.emplace_back(weight_params[i].scale);
596   }
597   size_t size = std::max(sizeX, sizeY);
598   for (size_t i = 0; i < size; i++) {
599     auto scaleX = sizeX > 1 ? input_scales[i] : input_scales[0];
600     auto scaleY = sizeY > 1 ? filter_scales[i] : filter_scales[0];
601     bias_scales.push_back(scaleX * scaleY);
602   }
603   MS_ASSERT(!bias_scales.empty());
604   size_t shape_size = bias_param->DataSize();
605 
606   // set bias quant param
607   std::vector<schema::QuantParamT> quant_params;
608   for (double bias_scale : bias_scales) {
609     schema::QuantParamT quant_param;
610     if (bias_scale == 0) {
611       MS_LOG(WARNING) << "bias_scale == 0";
612       quant_param.scale = 1;
613     } else {
614       quant_param.scale = bias_scale;
615     }
616     quant_param.zeroPoint = 0;
617     quant_param.inited = true;
618     quant_params.emplace_back(quant_param);
619   }
620   // quant bias data
621   std::vector<int32_t> quant_datas(shape_size);
622 
623   auto *raw_datas = static_cast<float *>(bias_param->data_c());
624   if (ComputeBiasDataAndQuantParam(bias_scales, input_scales, raw_datas, quant_param_holder, &quant_params,
625                                    &quant_datas) != RET_OK) {
626     MS_LOG(ERROR) << "compute bias data failed.";
627     return RET_ERROR;
628   }
629   quant_param_holder->set_input_quant_param(THIRD_INPUT, quant_params);
630   auto ret = SetTensorData(bias_param, quant_datas.data(), shape_size * sizeof(int32_t));
631   if (ret != RET_OK) {
632     MS_LOG(ERROR) << "set tensor data failed.";
633     return RET_ERROR;
634   }
635   // set dtype
636   auto abstractBase = bias_parameter_ptr->abstract();
637   if (abstractBase == nullptr) {
638     MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << bias_parameter_ptr->name();
639     return RET_ERROR;
640   }
641   if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
642     MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << bias_parameter_ptr->name();
643     return RET_ERROR;
644   }
645   auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
646   if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
647     MS_LOG(ERROR) << "abstractTensor is nullptr" << bias_parameter_ptr->name();
648     return RET_NULL_PTR;
649   }
650   abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
651   return RET_OK;
652 }
653 
DoParameterNodeQuant(const CNodePtr & cnode,const AnfNodePtr & input_node,size_t input_index)654 STATUS FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const AnfNodePtr &input_node,
655                                                 size_t input_index) {
656   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
657   if (primitive == nullptr) {
658     return RET_ERROR;
659   }
660   auto op_name = cnode->fullname_with_scope();
661   STATUS ret;
662   TypeId type_id = kTypeUnknown;
663   if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
664     MS_LOG(ERROR) << "Get data type failed.";
665     return RET_ERROR;
666   }
667   // support for share weight.
668   if (type_id == kNumberTypeInt8) {
669     return RET_CONTINUE;
670   }
671   if (type_id != kNumberTypeFloat32) {
672     ret = SetInOutQuantParam(input_node, nullptr, primitive, true, input_index - 1);
673     if (ret != RET_OK) {
674       MS_LOG(ERROR) << "Set In/Out quant param failed.";
675       return ret;
676     }
677     return RET_QUANT_CONTINUE;
678   }
679   if (CheckNodeInSet(cnode, has_bias_operator)) {
680     if (input_index == FOURTH_INPUT) {
681       ret = DoBiasQuant(input_node, primitive);
682       if (ret != RET_OK) {
683         MS_LOG(ERROR) << "Do bias quant failed.";
684         return ret;
685       }
686     } else {
687       if (opt::CheckPrimitiveType(cnode, prim::kPrimMatMul)) {
688         ret = DoWeightQuant(op_name, input_node, primitive, false, input_index);
689       } else {
690         ret = DoWeightQuant(op_name, input_node, primitive, true, input_index);
691       }
692       if (ret != RET_OK) {
693         MS_LOG(ERROR) << "Do bias quant failed.";
694         return ret;
695       }
696     }
697   } else {
698     ret = DoWeightQuant(op_name, input_node, primitive, false, input_index);
699     if (ret != RET_OK) {
700       MS_LOG(ERROR) << "Do bias quant failed.";
701       return ret;
702     }
703   }
704   return RET_OK;
705 }
706 
/**
 * Assign quantization parameters to every input of an ordinary (non-TupleGetItem) node.
 * Graph inputs and un-inited activation inputs consume the calibrator's recorded
 * DivergInfo in order (tracked by activation_input_index); already-quantized producer
 * outputs are propagated; constant Parameter inputs go through DoParameterNodeQuant.
 * @param cnode node whose inputs are to be quantized.
 * @return RET_OK on success, otherwise an error status.
 */
STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    return RET_ERROR;
  }
  auto op_name = cnode->fullname_with_scope();
  auto primitive_quant_holder = GetCNodeQuantHolder(primitive);
  MS_CHECK_TRUE_MSG(primitive_quant_holder != nullptr, RET_NULL_PTR, "primitive_quant_holder is nullptr.");
  // Index into this op's recorded activation DivergInfo list; only advanced for
  // inputs that are activations (graph inputs or producer-cnode outputs).
  size_t activation_input_index = 0;
  STATUS ret;
  // input(0) is the primitive value node, so real inputs start at 1.
  for (size_t i = 1; i < cnode->inputs().size(); i++) {
    auto input_node = cnode->input(i);
    MS_ASSERT(input_node != nullptr);
    // A Parameter without a default value is a graph input (an activation),
    // not a constant weight/bias.
    bool is_graph_input = false;
    if (input_node->isa<Parameter>()) {
      if (!input_node->cast<ParameterPtr>()->has_default()) {
        is_graph_input = true;
      }
    }
    if (is_graph_input) {
      // do input quant
      auto &info = (*inputs_diverg_info)[op_name][activation_input_index++];
      ret = SetInOutQuantParam(input_node, info, primitive, true, i - 1);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Set activation quant failed.";
        return ret;
      }
    } else if (input_node->isa<mindspore::CNode>()) {
      auto input_cnode = input_node->cast<mindspore::CNodePtr>();
      auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
      if (input_cnode_primitive == nullptr) {
        MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
                      << " Primitive is null";
        continue;
      }
      auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive);
      MS_CHECK_TRUE_MSG(input_primitive_quant_holder != nullptr, RET_NULL_PTR,
                        "input_primitive_quant_holder is nullptr.");
      if (input_primitive_quant_holder->IsOutputQuantParamsInited()) {
        // Producer already has output quant params: reuse them for this input.
        auto quant_param = input_primitive_quant_holder->get_output_quant_params().front();
        primitive_quant_holder->set_input_quant_param(i - 1, quant_param);
        // Still consumes one activation slot so later inputs stay aligned.
        activation_input_index++;
      } else {
        // do input quant
        auto &info = (*inputs_diverg_info)[op_name][activation_input_index++];
        ret = SetInOutQuantParam(input_node, info, primitive, true, i - 1);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "Set activation quant failed.";
          return ret;
        }
      }
    } else if (input_node->isa<mindspore::Parameter>()) {
      // Parameter WITH a default value, i.e. a constant weight/bias tensor.
      ret = DoParameterNodeQuant(cnode, input_node, i);
      // RET_QUANT_CONTINUE means the input was already handled (e.g. shared
      // int8 weight) and needs no further work.
      if (ret == RET_QUANT_CONTINUE) {
        continue;
      } else if (ret != RET_OK) {
        MS_LOG(ERROR) << "Do parameter node quant failed.";
        return ret;
      }
    } else {
      MS_LOG(ERROR) << input_node->fullname_with_scope() << ":" << input_node->type_name() << " is not support type";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
775 
/**
 * Write quantization parameters into every calibrated node of the graph.
 * TupleGetItem nodes just forward the selected output quant param of their
 * producer; all other nodes go through QuantNodeSimpleOp for inputs and then
 * get output quant params from the calibrator's output DivergInfo.
 * @return RET_OK on success, otherwise an error status.
 */
STATUS FullQuantQuantizer::QuantNode() {
  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
  auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();

  auto cnodes = funcGraph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    auto op_name = cnode->fullname_with_scope();
    auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "primitive is nullptr";
      continue;
    }
    auto primitive_quant_holder = GetCNodeQuantHolder(primitive);
    MS_CHECK_TRUE_MSG(primitive_quant_holder != nullptr, RET_NULL_PTR, "primitive_quant_holder is nullptr.");
    // Nodes without recorded input statistics were not calibrated; mark them
    // explicitly as not-quantized and move on.
    if (inputs_diverg_info->find(op_name) == inputs_diverg_info->end()) {
      MS_LOG(INFO) << op_name << " can not do quant";
      primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_NONE);
      continue;
    }

    auto op_type = primitive->name();
    MS_LOG(DEBUG) << "OpName: " << op_name;
    if (op_type == lite::kNameTupleGetItem) {
      // TupleGetItem(value, index): copy the producer's index-th output quant
      // param onto both this node's input and output.
      constexpr int tuple_get_item_input_size = 3;
      MS_CHECK_TRUE_MSG(cnode->size() == tuple_get_item_input_size, RET_ERROR, "cnode->size() != 3");
      auto index_node = cnode->input(THIRD_INPUT);
      auto index_value_node = index_node->cast<mindspore::ValueNodePtr>();
      if (index_value_node == nullptr) {
        MS_LOG(WARNING) << "index value node is null";
        continue;
      }
      size_t index = opt::CastToInt(index_value_node->value()).front();
      auto input_node = cnode->input(SECOND_INPUT);
      MS_CHECK_TRUE_MSG(input_node != nullptr, RET_ERROR, "input_node == nullptr");
      auto input_cnode = input_node->cast<mindspore::CNodePtr>();
      MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "input_cnode == nullptr");
      auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
      if (input_cnode_primitive == nullptr) {
        MS_LOG(WARNING) << "input_cnode_primitive is null";
        continue;
      }
      auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive);
      MS_CHECK_TRUE_MSG(input_primitive_quant_holder != nullptr, RET_NULL_PTR,
                        "input_primitive_quant_holder is nullptr.");

      if (input_primitive_quant_holder->get_output_quant_params().size() > index) {
        auto quant_param = input_primitive_quant_holder->get_output_quant_params()[index];
        primitive_quant_holder->set_input_quant_param(0, quant_param);
        primitive_quant_holder->set_output_quant_param(0, quant_param);
      } else {
        // Out-of-range index is tolerated: logged, node still marked QUANT_ALL.
        MS_LOG(WARNING) << "this TupleGetItem node's input node: " << input_cnode->fullname_with_scope()
                        << "'s output quant_params size: "
                        << input_primitive_quant_holder->get_output_quant_params().size() << ", but index: " << index;
      }
      primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
      continue;
    } else {  // do simple op quant
      auto status = QuantNodeSimpleOp(cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "simple op quant failed.";
        return status;
      }
    }
    // do output quant, there may multi-output
    auto &infos = (*outputs_diverg_info)[op_name];
    for (size_t index = 0; index < infos.size(); index++) {
      auto &info = infos.at(index);
      auto ret = SetInOutQuantParam(cnode, info, primitive, false, index);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Set In/Out quant param failed.";
        return ret;
      }
      primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
    }
  }
  return RET_OK;
}
853 
UpdateDivergeInterval()854 STATUS FullQuantQuantizer::UpdateDivergeInterval() {
855   auto ret = this->calibrator_->UpdateDivergInterval(this->calibrator_->GetInputDivergInfo());
856   if (ret != RET_OK) {
857     MS_LOG(ERROR) << "Update input diverge interval failed.";
858     return ret;
859   }
860   ret = this->calibrator_->UpdateDivergInterval(this->calibrator_->GetOutputDivergInfo());
861   if (ret != RET_OK) {
862     MS_LOG(ERROR) << "Update output diverge interval failed.";
863     return ret;
864   }
865   return RET_OK;
866 }
867 
868 /**
869  *  Mark quantifiable nodes
870  **/
PreProcess()871 STATUS FullQuantQuantizer::PreProcess() {
872   auto cnodes = funcGraph->GetOrderedCnodes();
873   for (auto &cnode : cnodes) {
874     AnfNodePtr anf = cnode->cast<AnfNodePtr>();
875     if (anf == nullptr) {
876       MS_LOG(ERROR) << " cnode is null";
877       return RET_NULL_PTR;
878     }
879     if (mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(anf)) {
880       auto ret = calibrator_->AddQuantizedOp(cnode);
881       if (ret != RET_OK) {
882         MS_LOG(ERROR) << "Add Quantized Op failed.";
883         return ret;
884       }
885     }
886     auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
887     if (primitive == nullptr) {
888       MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is null";
889       continue;
890     }
891     auto quant_param_holder = GetCNodeQuantHolder(primitive);
892     MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
893     quant_param_holder->ClearInputOutputQuantParam();
894   }
895   return RET_OK;
896 }
897 
CheckFp32TensorVec(const std::string & node_name,const std::vector<mindspore::tensor::MSTensor * > & tensor_vec)898 STATUS FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
899                                               const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) {
900   MS_ASSERT(tensor_vec != nullptr);
901   if (tensor_vec.empty()) {
902     MS_LOG(ERROR) << "node: " << node_name << " input tensors is 0";
903     return RET_ERROR;
904   }
905   auto *tensor = tensor_vec[0];
906   MS_ASSERT(tensor != nullptr);
907   if (tensor->data_type() != kNumberTypeFloat32) {
908     MS_LOG(INFO) << "node: " << node_name << " will not quantize"
909                  << " tensor data_type: " << tensor->data_type();
910     return RET_ERROR;
911   }
912   return RET_OK;
913 }
914 
915 /**
916  * 1. create input tensor
917  * 2. insert callback to session
918  * 3. run session
919  **/
/**
 * 1. create input tensor
 * 2. insert callback to session
 * 3. run session
 * Runs the fp32 model over every calibration batch, recording per-tensor
 * min/max values into the calibrator's input/output DivergInfo maps via the
 * before/after kernel callbacks.
 **/
STATUS FullQuantQuantizer::DoInference() {
  // get input tensor
  vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
  if (inputs.size() != calibrator_->GetInputNum()) {
    MS_LOG(ERROR) << "model's input tensor count: " << inputs.size() << " != "
                  << " calibrator count:" << calibrator_->GetInputNum();
    return RET_ERROR;
  }

  for (size_t calib_index = 0; calib_index < calibrator_->GetBatchNum(); calib_index++) {
    // set multi-input data
    for (size_t input_index = 0; input_index < inputs.size(); input_index++) {
      STATUS status =
        calibrator_->GenerateInputData(inputs[input_index]->tensor_name(), calib_index, inputs[input_index]);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "generate input data from images failed!";
        return RET_ERROR;
      }
    }

    // Records min/max of each non-const fp32 input activation of calibrated nodes.
    KernelCallBack beforeCallBack = [&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
                                        const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
                                        const CallBackParam &callParam) -> bool {
      auto diverg_info_map = calibrator_->GetInputDivergInfo();
      if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
        return true;
      }
      if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeOutputs) != RET_OK) {
        return true;
      }
      // First batch for a multi-input node: the calibrator seeded one DivergInfo,
      // clone it for every additional non-const fp32 input.
      bool is_init = beforeInputs.size() > 1 && (*diverg_info_map)[callParam.node_name].size() == 1;
      if (is_init) {
        for (size_t i = 1; i < beforeInputs.size(); i++) {
          if (beforeInputs.at(i)->data_type() != kNumberTypeFloat32 || beforeInputs.at(i)->IsConst()) {
            continue;
          }
          auto input_diverg = std::make_unique<DivergInfo>();
          MS_CHECK_TRUE_MSG(input_diverg != nullptr, false, "input_diverg is nullptr.");
          *input_diverg = *((*diverg_info_map)[callParam.node_name][0]);
          (*diverg_info_map)[callParam.node_name].push_back(std::move(input_diverg));
        }
      }
      // NOTE(review): this indexes beforeInputs[i] for each recorded DivergInfo;
      // it assumes the DivergInfo entries align 1:1 with the leading entries of
      // beforeInputs (const/non-fp32 inputs were skipped above) — confirm the
      // alignment holds when const inputs are interleaved.
      for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
        auto tensor = beforeInputs[i];
        MS_ASSERT(tensor != nullptr);
        const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
        MS_CHECK_TRUE_MSG(tensor_data != nullptr, false, "tensor_data is nullptr.");
        size_t elem_count = tensor->ElementsNum();
        vector<float> data(tensor_data, tensor_data + elem_count);
        auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[callParam.node_name][i]);
        MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record MaxMinValue failed!");
      }
      return true;
    };
    // func
    // Records min/max of each output activation of calibrated nodes.
    KernelCallBack afterCallBack = [&](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
                                       const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
                                       const CallBackParam &callParam) -> bool {
      auto diverg_info_map = calibrator_->GetOutputDivergInfo();
      if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
        return true;
      }
      if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
        return true;
      }
      // First batch for a multi-output node: clone the seed DivergInfo for each
      // additional output.
      bool is_init = afterOutputs.size() > 1 && (*diverg_info_map)[callParam.node_name].size() == 1;
      if (is_init) {
        for (size_t i = 1; i < afterOutputs.size(); i++) {
          auto output_diverg = std::make_unique<DivergInfo>();
          CHECK_NULL_RETURN(output_diverg);
          *output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
          (*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg));
        }
      }
      size_t output_i = 0;
      for (const auto &tensor : afterOutputs) {
        const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
        CHECK_NULL_RETURN(tensor_data);
        size_t elem_count = tensor->ElementsNum();
        vector<float> data(tensor_data, tensor_data + elem_count);
        auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[callParam.node_name][output_i]);
        MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record MaxMinValue failed!");
        output_i++;
      }
      return true;
    };
    auto status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "run model failed!";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
1014 
Int8Inference()1015 STATUS FullQuantQuantizer::Int8Inference() {
1016   // int8 inference
1017   vector<mindspore::tensor::MSTensor *> inputs = int8_session_->GetInputs();
1018   for (auto input_tensor : inputs) {
1019     // get input tensor
1020     auto elem_count = input_tensor->ElementsNum();
1021     vector<float> dummy_data(elem_count);
1022     // set the input data to 0.1
1023     std::fill(dummy_data.begin(), dummy_data.end(), 0.1);
1024     auto ret =
1025       memcpy_s(input_tensor->MutableData(), input_tensor->Size(), dummy_data.data(), sizeof(float) * dummy_data.size());
1026     if (ret != EOK) {
1027       MS_LOG(ERROR) << "memcpy_s error: " << ret;
1028       return RET_ERROR;
1029     }
1030   }
1031 
1032   for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
1033     // before func
1034     KernelCallBack before_call_back = GetBeforeCallBack(true);
1035     // after func
1036     KernelCallBack after_call_back = GetAfterCallBack(true);
1037     auto ret = int8_session_->RunGraph(before_call_back, after_call_back);
1038     if (ret != RET_OK) {
1039       MS_LOG(ERROR) << "run model failed!";
1040       return RET_ERROR;
1041     }
1042   }  // end for images
1043   return RET_OK;
1044 }
1045 
/**
 * Bias-correction driver: runs fp32 inference on real calibration data while
 * Int8Inference runs asynchronously on dummy data; the paired callbacks
 * accumulate per-channel output differences into op_bias_diff_map, which is
 * then averaged over the batch count and folded into each node's bias.
 * @param func_graph graph whose nodes are corrected.
 * @return RET_OK on success, otherwise an error status.
 */
STATUS FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
  // The int8 run executes concurrently; its callbacks spin-wait on data the
  // fp32 callbacks publish via OpInputDataHandle/OpOutputChMeanDataHandle.
  std::future<STATUS> int8_inference = std::async(std::launch::async, &FullQuantQuantizer::Int8Inference, this);
  // get input tensor
  vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
  if (inputs.size() != 1) {
    MS_LOG(ERROR) << "model's input tensor size: " << inputs.size();
    // NOTE(review): the early returns here and below abandon `int8_inference`
    // without get(); a std::async future's destructor blocks until the task
    // finishes, and the int8 callbacks spin-wait for fp32 data that will never
    // arrive — this looks like a potential hang on the error path. Confirm.
    return RET_ERROR;
  }
  // fp32 inference
  for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
    for (size_t input_index = 0; input_index < inputs.size(); input_index++) {
      STATUS status = calibrator_->GenerateInputData(inputs[input_index]->tensor_name(), i, inputs[input_index]);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "generate input data from images failed!";
        return RET_ERROR;
      }
    }
    // before func
    KernelCallBack before_call_back = GetBeforeCallBack(false);
    // after func
    KernelCallBack after_call_back = GetAfterCallBack(false);
    auto status = fp32_session_->RunGraph(before_call_back, after_call_back);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "run model failed!";
      return RET_ERROR;
    }
  }  // end for images

  // Wait for the concurrent int8 pass to finish before using its results.
  STATUS status = int8_inference.get();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "int8 inference failed!";
    return RET_ERROR;
  }
  if (calibrator_->GetBatchNum() == 0) {
    MS_LOG(ERROR) << "divisor 'calibrate_size' cannot be 0.";
    return RET_ERROR;
  }
  // Average the accumulated per-channel differences over all batches.
  for (auto &key_value : op_bias_diff_map) {
    std::for_each(key_value.second.begin(), key_value.second.end(),
                  [this](float &data) { data = data / calibrator_->GetBatchNum(); });
  }
  // Apply the averaged correction to every node that has a recorded diff.
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    auto op_name = cnode->fullname_with_scope();
    if (op_bias_diff_map.find(op_name) == op_bias_diff_map.end()) {
      continue;
    }
    status = BiasCorrection(func_graph, cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "do node bias correct failed.";
      break;
    }
  }
  return status;
}
1101 
BiasCorrection(const FuncGraphPtr & func_graph,const CNodePtr & cnode)1102 STATUS FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
1103   auto op_name = cnode->fullname_with_scope();
1104   const auto &bias_diff = op_bias_diff_map[op_name];
1105   auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
1106   if (primitive == nullptr) {
1107     MS_LOG(ERROR) << "primitive is nullptr";
1108     return RET_NULL_PTR;
1109   }
1110   auto quant_param_holder = GetCNodeQuantHolder(primitive);
1111   MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
1112   auto input_quant_params = quant_param_holder->get_input_quant_params();
1113   if (input_quant_params.size() == DIMENSION_3D) {
1114     // compensate the existed
1115     auto bias_quant_params = input_quant_params.at(THIRD_INPUT);
1116     auto bias = cnode->input(FOURTH_INPUT);
1117     auto bias_parameter_ptr = bias->cast<ParameterPtr>();
1118     auto bias_default_param = bias_parameter_ptr->default_param();
1119     auto bias_param = bias_default_param->cast<tensor::TensorPtr>();
1120     int *bias_datas = static_cast<int *>(bias_param->data_c());
1121 
1122     if (static_cast<size_t>(bias_param->DataSize()) != bias_diff.size()) {
1123       MS_LOG(DEBUG) << "unexpected bias data count: " << bias_param->DataSize()
1124                     << " not the same as bias_diff: " << bias_diff.size();
1125       return RET_ERROR;
1126     }
1127     if (bias_quant_params.size() != bias_diff.size()) {
1128       MS_LOG(ERROR) << "unexpected bias quant params size: " << bias_quant_params.size()
1129                     << " not the same as bias_diff: " << bias_diff.size();
1130       return RET_ERROR;
1131     }
1132     for (int i = 0; i < bias_param->DataSize(); i++) {
1133       auto scale = bias_quant_params[i].scale;
1134       if (fabs(scale) <= 0.0f) {
1135         MS_LOG(ERROR) << "divisor 'scale' cannot be 0.";
1136         return RET_ERROR;
1137       }
1138       double after_correct = std::round(bias_diff[i] / scale) + bias_datas[i];
1139       const constexpr int32_t corrected_bias_abs_limit = 0.6 * INT32_MAX;
1140       if (after_correct > corrected_bias_abs_limit) {
1141         MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too large: " << after_correct
1142                         << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i] << " scale: " << scale;
1143         bias_datas[i] = static_cast<int>(corrected_bias_abs_limit);
1144       } else if (after_correct < -corrected_bias_abs_limit) {
1145         MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too small: " << after_correct
1146                         << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i] << " scale: " << scale;
1147         bias_datas[i] = static_cast<int>(-corrected_bias_abs_limit);
1148       } else {
1149         auto diff = static_cast<int>(std::round(bias_diff[i] / scale));
1150         bias_datas[i] += diff;
1151       }
1152     }
1153   } else if (input_quant_params.size() == DIMENSION_2D) {
1154     MS_LOG(INFO) << op_name << " add bias input";
1155     // need to add bias input
1156     auto parameter = func_graph->add_parameter();
1157     if (parameter == nullptr) {
1158       MS_LOG(ERROR) << "parameter is nullptr.";
1159       return RET_NULL_PTR;
1160     }
1161     ShapeVector shape;
1162     shape.push_back(bias_diff.size());
1163 
1164     auto tensor_info = CreateTensorInfo(bias_diff.data(), sizeof(float) * bias_diff.size(), shape, kNumberTypeFloat32);
1165     if (tensor_info == nullptr) {
1166       MS_LOG(ERROR) << "create tensor info failed.";
1167       return RET_ERROR;
1168     }
1169     auto status = InitParameterFromTensorInfo(parameter, tensor_info);
1170     if (status != RET_OK) {
1171       MS_LOG(ERROR) << "init parameter from tensor info failed";
1172       return RET_ERROR;
1173     }
1174     parameter->set_name("added_" + op_name + "_bias");
1175     cnode->add_input(parameter);
1176     status = DoBiasQuant(parameter, primitive);
1177     if (status != RET_OK) {
1178       MS_LOG(ERROR) << "Do bias quant failed.";
1179       return RET_ERROR;
1180     }
1181   } else {
1182     MS_LOG(ERROR) << "unexpected get_input_quant_params size: " << input_quant_params.size();
1183   }
1184   return RET_OK;
1185 }
1186 
/**
 * Second calibration pass: re-run the fp32 model over all batches and build
 * value-frequency histograms (via Calibrator::UpdateDataFrequency) inside the
 * intervals fixed by UpdateDivergeInterval, for both input and output
 * activations of every calibrated node.
 * @return RET_OK on success, otherwise an error status.
 */
STATUS FullQuantQuantizer::CollectDataFrequency() {
  // get input tensor
  vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
  if (inputs.size() != calibrator_->GetInputNum()) {
    MS_LOG(ERROR) << "model's input tensor cnt: " << inputs.size() << " != " << calibrator_->GetInputNum();
    return RET_ERROR;
  }

  for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
    // set multi-input data
    for (size_t input_index = 0; input_index < inputs.size(); input_index++) {
      STATUS status = calibrator_->GenerateInputData(inputs[input_index]->tensor_name(), i, inputs[input_index]);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "generate input data from images failed!";
        return RET_ERROR;
      }
    }

    // Accumulate histograms for each non-const fp32 input activation;
    // input_i only advances for tensors that pass the filter, matching the
    // DivergInfo list built during DoInference.
    KernelCallBack before_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                                         const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                                         const CallBackParam &callParam) {
      auto diverg_info_map = calibrator_->GetInputDivergInfo();
      if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
        return true;
      }
      if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, before_inputs) != RET_OK) {
        return true;
      }
      int input_i = 0;
      for (auto tensor : before_inputs) {
        if (tensor->data_type() != kNumberTypeFloat32 || tensor->IsConst()) {
          continue;
        }
        const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
        MS_ASSERT(tensor_data != nullptr);
        size_t elem_count = tensor->ElementsNum();
        vector<float> data(tensor_data, tensor_data + elem_count);
        auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name][input_i++]);
        if (ret != RET_OK) {
          return false;
        }
      }
      return true;
    };

    // Accumulate histograms for every output activation.
    KernelCallBack after_callBack = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
                                        const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
                                        const CallBackParam &call_param) {
      auto diverg_info_map = calibrator_->GetOutputDivergInfo();
      if (diverg_info_map->find(call_param.node_name) == diverg_info_map->end()) {
        return true;
      }
      if (FullQuantQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
        return true;
      }
      int output_i = 0;
      // all outputs are same dtype.
      for (const auto &tensor : after_outputs) {
        const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
        MS_ASSERT(tensor_data != nullptr);
        size_t elem_count = tensor->ElementsNum();
        vector<float> data(tensor_data, tensor_data + elem_count);
        auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[call_param.node_name][output_i++]);
        if (ret != RET_OK) {
          return false;
        }
      }
      return true;
    };
    auto status = fp32_session_->RunGraph(before_callback, after_callBack);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "run model failed!";
      return RET_ERROR;
    }
  }

  return RET_OK;
}
1265 
ComputeThreshold()1266 STATUS FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
1267 
/**
 * Full-quantization entry point. Pipeline:
 *   validate config -> PreProcess (mark nodes) -> fp32 session ->
 *   DoInference (min/max) -> UpdateDivergeInterval -> CollectDataFrequency
 *   (histograms) -> ComputeThreshold -> QuantNode (write quant params) ->
 *   insert QuantCast nodes -> optional bias correction with an int8 session.
 * @param func_graph graph to quantize in place.
 * @return RET_OK on success, otherwise an error status.
 */
STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  MS_LOG(INFO) << "start to parse config file";
  if (this->calibrator_ == nullptr) {
    MS_LOG(ERROR) << "calibrator is null!";
    return RET_ERROR;
  }
  calibrator_->full_quant_param_ = flags.fullQuantParam;
  calibrator_->data_pre_process_param_ = flags.dataPreProcessParam;
  // Calibration data is mandatory for full quantization.
  if (flags.dataPreProcessParam.calibrate_path.empty()) {
    MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
    return RET_INPUT_PARAM_INVALID;
  }
  if (flags.dataPreProcessParam.calibrate_size <= kMinSize || flags.dataPreProcessParam.calibrate_size > kMaxSize) {
    MS_LOG(ERROR) << "calibrate size must pass and the size should in [1, 65535].";
    return RET_INPUT_PARAM_INVALID;
  }
  if (flags.dataPreProcessParam.input_type == preprocess::INPUT_TYPE_MAX) {
    MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
    return RET_INPUT_PARAM_INVALID;
  }
  STATUS status = PreProcess();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "do pre process failed!";
    return status;
  }

  // anf -- fb
  // Calibration runs on the un-quantized model.
  flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
  MS_LOG(INFO) << "start create session";
  auto sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
  fp32_session_ = sm.session;
  fp32_model_ = sm.model;
  if (fp32_session_ == nullptr || fp32_model_ == nullptr) {
    MS_LOG(ERROR) << "create session failed!";
    return RET_ERROR;
  }
  MS_LOG(INFO) << "start to update divergence's max value";
  status = DoInference();
  if (status != RET_OK) {
    return status;
  }
  MS_LOG(INFO) << "start to update divergence's interval";
  status = UpdateDivergeInterval();
  if (status != RET_OK) {
    return status;
  }
  MS_LOG(INFO) << "start to collect data's distribution";
  status = CollectDataFrequency();
  if (status != RET_OK) {
    return status;
  }
  MS_LOG(INFO) << "compute the best threshold";
  status = ComputeThreshold();
  if (status != RET_OK) {
    return status;
  }
  MS_LOG(INFO) << "start to generate quant param and quantize tensor's data";
  status = QuantNode();
  if (status != RET_OK) {
    return status;
  }

  // add quant_cast
  quant::QuantCast quant_cast;
  status = quant_cast.Run(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "add QuantCast error";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return RET_ERROR;
  }

  if (calibrator_->GetBiasCorrection()) {
    // init in8 session
    MS_LOG(INFO) << "create quant session";
    flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
    auto int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
    int8_session_ = int8_sm.session;
    int8_model_ = int8_sm.model;
    if (int8_session_ == nullptr || int8_model_ == nullptr) {
      MS_LOG(ERROR) << "create session failed!";
      return RET_ERROR;
    }
    MS_LOG(INFO) << "do bias correction";
    status = BiasCorrection(func_graph);
    if (status != RET_OK) {
      // Bias correction failure is non-fatal: the quantized model is still usable.
      MS_LOG(WARNING) << "BiasCorrection failed.";
    }
  }
  return RET_OK;
}
1358 
OpInputDataHandle(OperationType type,const string & op_name,std::vector<float> * data)1359 bool FullQuantQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
1360   MS_ASSERT(data != nullptr);
1361   std::lock_guard<std::mutex> lg(mutex_op_input);
1362   if (type == STORE) {
1363     if (fp32_op_input_map.find(op_name) != fp32_op_input_map.end()) {
1364       // the data has not been fetched by int8 model
1365       return false;
1366     }
1367     fp32_op_input_map[op_name] = *data;
1368     return true;
1369   } else if (type == FETCH) {
1370     if (fp32_op_input_map.find(op_name) == fp32_op_input_map.end()) {
1371       // the data not generated by fp32 model yet
1372       return false;
1373     }
1374     *data = fp32_op_input_map[op_name];
1375     fp32_op_input_map.erase(op_name);
1376     return true;
1377   } else {
1378     MS_LOG(ERROR) << "unexpected type: " << type;
1379   }
1380   return false;
1381 }
1382 
OpOutputChMeanDataHandle(OperationType type,const string & op_name,std::vector<float> * data)1383 bool FullQuantQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
1384   MS_ASSERT(data != nullptr);
1385   std::lock_guard<std::mutex> lg(mutex_op_output);
1386   if (type == STORE) {
1387     if (fp32_op_output_ch_mean_map.find(op_name) != fp32_op_output_ch_mean_map.end()) {
1388       // the data has not been fetched by int8 model
1389       return false;
1390     }
1391     fp32_op_output_ch_mean_map[op_name] = *data;
1392     return true;
1393   } else if (type == FETCH) {
1394     if (fp32_op_output_ch_mean_map.find(op_name) == fp32_op_output_ch_mean_map.end()) {
1395       // the data not generated by fp32 model yet
1396       return false;
1397     }
1398     *data = fp32_op_output_ch_mean_map[op_name];
1399     fp32_op_output_ch_mean_map.erase(op_name);
1400     return true;
1401   } else {
1402     MS_LOG(ERROR) << "unexpected type: " << type;
1403   }
1404   return false;
1405 }
1406 
GetBeforeCallBack(bool int8_op)1407 KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
1408   KernelCallBack before_call_back;
1409   if (!int8_op) {
1410     before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
1411                               const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
1412                               const CallBackParam &callParam) -> bool {
1413       if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
1414         if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, before_inputs) != RET_OK) {
1415           return true;
1416         }
1417         auto tensor = before_inputs[0];
1418         MS_ASSERT(tensor != nullptr);
1419         size_t elem_count = tensor->ElementsNum();
1420         std::vector<float> fp32_op_input(elem_count);
1421         auto ret =
1422           memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->MutableData(), tensor->Size());
1423         if (ret != EOK) {
1424           MS_LOG(ERROR) << "memcpy error: " << ret;
1425           return false;
1426         }
1427         while (!OpInputDataHandle(STORE, callParam.node_name, &fp32_op_input)) {
1428           std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));
1429         }
1430       }
1431       return true;
1432     };
1433   } else {
1434     before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
1435                               const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
1436                               const CallBackParam &callParam) -> bool {
1437       if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
1438         vector<float> fp32_op_input;
1439         while (!OpInputDataHandle(FETCH, callParam.node_name, &fp32_op_input)) {
1440           std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));
1441         }
1442         auto tensor = before_inputs[0];
1443         MS_ASSERT(tensor != nullptr);
1444         if (tensor->data_type() != kNumberTypeInt8) {
1445           MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
1446           return false;
1447         }
1448         // do quantization: activation is always per layer quantized
1449         std::vector<int8_t> quant_datas;
1450         auto quant_params = tensor->quant_params();
1451         if (quant_params.size() != 1) {
1452           MS_LOG(ERROR) << "unexpected quant_params size: " << quant_params.size();
1453           return false;
1454         }
1455         schema::QuantParamT quant_param_t;
1456         quant_param_t.scale = quant_params[0].scale;
1457         quant_param_t.zeroPoint = quant_params[0].zeroPoint;
1458         for (auto float_data : fp32_op_input) {
1459           auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, quant_max, quant_min);
1460           quant_datas.push_back(quant_data);
1461         }
1462 
1463         if (tensor->Size() != quant_datas.size() * sizeof(int8_t)) {
1464           MS_LOG(ERROR) << "unexpected tensor size: " << quant_datas.size()
1465                         << " not the same with: " << quant_datas.size() * sizeof(int8_t);
1466           return false;
1467         }
1468 
1469         auto ret =
1470           memcpy_s(tensor->MutableData(), tensor->Size(), quant_datas.data(), quant_datas.size() * sizeof(int8_t));
1471         if (ret != EOK) {
1472           MS_LOG(ERROR) << "memcpy error: " << ret;
1473           return false;
1474         }
1475       }
1476       return true;
1477     };
1478   }
1479   return before_call_back;
1480 }
1481 
GetAfterCallBack(bool int8_op)1482 KernelCallBack FullQuantQuantizer::GetAfterCallBack(bool int8_op) {
1483   KernelCallBack after_call_back;
1484   if (!int8_op) {
1485     return GetFloatAfterCallBack();
1486   }
1487   return GetInt8AfterCallBack();
1488 }
1489 
GetInt8AfterCallBack()1490 KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
1491   KernelCallBack after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
1492                                           const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
1493                                           const CallBackParam &callParam) -> bool {
1494     if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
1495       vector<float> fp32_op_output_ch_mean;
1496       while (!OpOutputChMeanDataHandle(FETCH, callParam.node_name, &fp32_op_output_ch_mean)) {
1497         std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));
1498       }
1499       auto tensor = afterOutputs[0];
1500       MS_ASSERT(tensor != nullptr);
1501       if (tensor->data_type() != kNumberTypeInt8) {
1502         MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
1503         return false;
1504       }
1505       const int8_t *tensor_data = static_cast<int8_t *>(tensor->MutableData());
1506       size_t elem_count = tensor->ElementsNum();
1507       auto shapes = tensor->shape();
1508       if (shapes.size() != 4) {
1509         MS_LOG(ERROR) << "unexpected shape size: " << shapes.size();
1510         return false;
1511       }
1512       // suppose the the format is NHWC
1513       auto channels = shapes[3];
1514       if (channels == 0) {
1515         MS_LOG(ERROR) << "unexpected channels: 0";
1516         return false;
1517       }
1518       auto quant_params = tensor->quant_params();
1519       if (quant_params.size() != 1) {
1520         MS_LOG(ERROR) << "unexpected activatation quant_params size: " << quant_params.size();
1521         return false;
1522       }
1523       auto scale = quant_params[0].scale;
1524       auto zp = quant_params[0].zeroPoint;
1525       std::vector<float> dequant_op_output_ch_mean(channels);
1526       auto one_filter_size = elem_count / channels;
1527       for (int i = 0; i < channels; i++) {
1528         float sum = 0;
1529         for (size_t j = 0; j < one_filter_size; j++) {
1530           auto index = j * channels + i;
1531           if (index >= elem_count) {
1532             MS_LOG(ERROR) << "over flow!";
1533             return false;
1534           }
1535           // deuqant activation
1536           auto float_data = scale * (tensor_data[index] - zp);
1537           sum += float_data;
1538         }
1539         if (one_filter_size == 0) {
1540           MS_LOG(ERROR) << "divisor 'one_filter_size' cannot be 0.";
1541           return false;
1542         }
1543         sum = sum / one_filter_size;
1544         dequant_op_output_ch_mean[i] = sum;
1545       }
1546       std::transform(fp32_op_output_ch_mean.begin(), fp32_op_output_ch_mean.end(), dequant_op_output_ch_mean.begin(),
1547                      dequant_op_output_ch_mean.begin(), std::minus<>());
1548 
1549       if (op_bias_diff_map.find(callParam.node_name) != op_bias_diff_map.end()) {
1550         auto &bias_diff = op_bias_diff_map[callParam.node_name];
1551         std::transform(bias_diff.begin(), bias_diff.end(), dequant_op_output_ch_mean.begin(), bias_diff.begin(),
1552                        std::plus<>());
1553       } else {
1554         op_bias_diff_map[callParam.node_name] = dequant_op_output_ch_mean;
1555       }
1556     }
1557     return true;
1558   };
1559   return after_call_back;
1560 }
1561 
GetFloatAfterCallBack()1562 KernelCallBack FullQuantQuantizer::GetFloatAfterCallBack() {
1563   KernelCallBack after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
1564                                           const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
1565                                           const CallBackParam &callParam) -> bool {
1566     if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
1567       if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
1568         return true;
1569       }
1570       auto tensor = afterOutputs[0];
1571       MS_ASSERT(tensor != nullptr);
1572       const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
1573       size_t elem_count = tensor->ElementsNum();
1574       auto shapes = tensor->shape();
1575       if (shapes.size() != 4) {
1576         MS_LOG(ERROR) << "unexpected shape size: " << shapes.size();
1577         return false;
1578       }
1579       // suppose the activation format: NHWC
1580       auto channels = shapes[3];
1581       if (channels == 0) {
1582         MS_LOG(ERROR) << "unexpected channels: 0";
1583         return false;
1584       }
1585       std::vector<float> fp32_op_output_ch_mean(channels);
1586       auto one_filter_size = elem_count / channels;
1587       for (int i = 0; i < channels; i++) {
1588         float sum = 0;
1589         for (size_t j = 0; j < one_filter_size; j++) {
1590           auto index = j * channels + i;
1591           if (index >= elem_count) {
1592             MS_LOG(ERROR) << "over flow!";
1593             return false;
1594           }
1595           sum += tensor_data[index];
1596         }
1597         if (one_filter_size == 0) {
1598           MS_LOG(ERROR) << "divisor 'one_filter_size' cannot be 0.";
1599           return false;
1600         }
1601         sum = sum / one_filter_size;
1602         fp32_op_output_ch_mean[i] = sum;
1603       }
1604       while (!OpOutputChMeanDataHandle(STORE, callParam.node_name, &fp32_op_output_ch_mean)) {
1605         std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));
1606       }
1607     }
1608     return true;
1609   };
1610   return after_call_back;
1611 }
1612 }  // namespace mindspore::lite::quant
1613