1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "tools/converter/quantizer/full_quant_quantizer.h"
18 #include <dirent.h>
19 #include <future>
20 #include <map>
21 #include <set>
22 #include <memory>
23 #include <algorithm>
24 #include <unordered_map>
25 #include <functional>
26 #include <numeric>
27 #include <utility>
28 #include <string>
29 #include <thread>
30 #include <vector>
31 #include <fstream>
32 #include "ops/fusion/full_connection.h"
33 #include "tools/converter/ops/ops_def.h"
34 #include "src/tensor.h"
35 #include "tools/converter/quantizer/quant_cast.h"
36 #include "tools/converter/quantizer/quantize_util.h"
37 #include "tools/optimizer/common/gllo_utils.h"
38 #include "tools/optimizer/common/format_utils.h"
39 #include "src/common/log_adapter.h"
40 #include "securec/include/securec.h"
41 #include "tools/common/tensor_util.h"
42 #include "src/common/quant_utils.h"
43 #include "src/common/utils.h"
44 #include "tools/converter/preprocess/image_preprocess.h"
45 #include "nnacl/op_base.h"
46
47 using std::string;
48 using std::vector;
49
namespace mindspore::lite::quant {
namespace {
// Operators whose weight-bearing inputs may include a bias tensor; these get
// dedicated bias handling in DoParameterNodeQuant/DoBiasQuant below.
static const std::set<PrimitivePtr> has_bias_operator = {prim::kPrimConv2DFusion, prim::kPrimConv2dTransposeFusion,
                                                         prim::kPrimMatMul, prim::kPrimFullConnection,
                                                         prim::kPrimLayerNormFusion};
// Size bounds — presumably used to validate calibration/tensor sizes elsewhere
// in this file (usage not visible in this chunk; confirm against callers).
constexpr int kMinSize = 0;
constexpr int kMaxSize = 65535;
}  // namespace
58 namespace {
ComputeBiasDataAndQuantParam(const std::vector<double> & bias_scales,const std::vector<double> & input_scales,const float * raw_datas,const QuantParamHolderPtr & quant_param_holder,std::vector<schema::QuantParamT> * quant_params,std::vector<int32_t> * quant_datas)59 STATUS ComputeBiasDataAndQuantParam(const std::vector<double> &bias_scales, const std::vector<double> &input_scales,
60 const float *raw_datas, const QuantParamHolderPtr &quant_param_holder,
61 std::vector<schema::QuantParamT> *quant_params, std::vector<int32_t> *quant_datas) {
62 MS_ASSERT(raw_datas != nullptr && quant_param_holder != nullptr);
63 MS_ASSERT(quant_params != nullptr && quant_datas != nullptr);
64 double bias_scale_tmp;
65 const constexpr double quanted_bias_abs_limit = 0.5 * INT32_MAX;
66 MS_CHECK_TRUE_MSG(quant_param_holder->get_input_quant_params().size() > 1, RET_ERROR, "invalid access.");
67 auto weight_quant_params = quant_param_holder->get_input_quant_params().at(1);
68 auto shape_size = quant_datas->size();
69 if (bias_scales.size() == shape_size) {
70 for (size_t i = 0; i < shape_size; i++) {
71 bias_scale_tmp = bias_scales[i];
72 if (fabs(bias_scale_tmp) <= 0.0f) {
73 MS_LOG(ERROR) << "divisor 'bias_scale_tmp' cannot be 0.";
74 return RET_ERROR;
75 }
76 if (std::abs(raw_datas[i] / bias_scale_tmp) >= quanted_bias_abs_limit) {
77 MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[i].scale
78 << " is too small, need to update";
79 // update filter scale and zp
80 double activate_scale = input_scales[0];
81 double filter_scale = std::abs(raw_datas[i]) / (activate_scale * quanted_bias_abs_limit);
82 weight_quant_params[i].scale = filter_scale;
83 weight_quant_params[i].zeroPoint = 0;
84 quant_param_holder->set_input_quant_param(1, weight_quant_params);
85 bias_scale_tmp = std::abs(raw_datas[i]) / quanted_bias_abs_limit;
86 quant_params->at(i).scale = bias_scale_tmp;
87 MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
88 }
89 auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
90 quant_datas->at(i) = quant_data;
91 }
92 return RET_OK;
93 } else if (bias_scales.size() == 1) {
94 // for fc, per tensor quant
95 bias_scale_tmp = quant_params->front().scale;
96 float max_raw_data = 0.0f;
97 for (size_t i = 0; i < shape_size; i++) {
98 if (std::abs(raw_datas[i]) > max_raw_data) {
99 max_raw_data = std::abs(raw_datas[i]);
100 }
101 }
102 if (fabs(bias_scale_tmp) <= 0.0f) {
103 MS_LOG(ERROR) << "divisor 'bias_scale_tmp' cannot be 0.";
104 return RET_ERROR;
105 }
106 if (std::abs(max_raw_data / bias_scale_tmp) >= quanted_bias_abs_limit) {
107 MS_LOG(DEBUG) << "quanted bias over flow, maybe the scale of weight: " << weight_quant_params[0].scale
108 << " is too small, need to update";
109 double activate_scale = input_scales[0];
110 MS_CHECK_TRUE_MSG(activate_scale != 0, RET_ERROR, "activate_scale == 0");
111 double filter_scale = std::abs(max_raw_data) / (activate_scale * quanted_bias_abs_limit);
112 weight_quant_params[0].scale = filter_scale;
113 weight_quant_params[0].zeroPoint = 0;
114 quant_param_holder->set_input_quant_param(1, weight_quant_params);
115 bias_scale_tmp = max_raw_data / quanted_bias_abs_limit;
116 quant_params->front().scale = bias_scale_tmp;
117 MS_LOG(DEBUG) << "new filter scale: " << filter_scale;
118 }
119 for (size_t i = 0; i < shape_size; i++) {
120 auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
121 quant_datas->at(i) = quant_data;
122 }
123 return RET_OK;
124 }
125 MS_LOG(ERROR) << "unexpected input_scales size: " << input_scales.size()
126 << " weight_scales size: " << weight_quant_params.size();
127 return RET_ERROR;
128 }
129 } // namespace
130
RecordMaxMinValue(const std::vector<float> & data)131 STATUS DivergInfo::RecordMaxMinValue(const std::vector<float> &data) {
132 for (float val : data) {
133 max = std::max(val, max);
134 min = std::min(val, min);
135 }
136 return RET_OK;
137 }
138
RecordMaxMinValueArray(const std::vector<float> & data)139 STATUS DivergInfo::RecordMaxMinValueArray(const std::vector<float> &data) {
140 if (data.empty()) {
141 return RET_ERROR;
142 }
143 float max_num = data.at(0);
144 float min_num = data.at(0);
145 for (float val : data) {
146 max_num = std::max(val, max_num);
147 min_num = std::min(val, min_num);
148 }
149 this->max_datas.emplace_back(max_num);
150 this->min_datas.emplace_back(min_num);
151 return RET_OK;
152 }
153
UpdateInterval()154 void DivergInfo::UpdateInterval() {
155 auto max_value = std::max(fabs(this->max), fabs(this->min));
156 MS_ASSERT(bin_num != 0);
157 this->interval = max_value / static_cast<float>(bin_num);
158 }
159
UpdateHistogram(const std::vector<float> & data)160 STATUS DivergInfo::UpdateHistogram(const std::vector<float> &data) {
161 for (auto value : data) {
162 if (value == 0) {
163 continue;
164 }
165 if (this->interval == 0) {
166 MS_LOG(ERROR) << "divisor 'interval' cannot be 0.";
167 return RET_ERROR;
168 }
169 int bin_index = std::min(static_cast<int>(std::fabs(value) / this->interval), bin_num - 1);
170 this->histogram[bin_index]++;
171 }
172 return RET_OK;
173 }
174
DumpHistogram()175 void DivergInfo::DumpHistogram() {
176 MS_LOG(INFO) << "Print node " << cnode->fullname_with_scope() << " histogram";
177 for (float item : this->histogram) {
178 std::cout << item << " ";
179 }
180 std::cout << std::endl;
181 }
182
/**
 * Prepare two histograms for a KL-divergence comparison.
 * Step 1: merge the first `bin_index` source bins into `quant_bint_nums` target
 *         bins (`quantized_histogram`); fractional edge bins contribute
 *         proportionally to their overlap.
 * Step 2: expand the merged bins back over `bin_index` bins
 *         (`expanded_histogram`), spreading each target bin's mass uniformly
 *         across the source bins that were non-zero, so it can be compared with
 *         the reference histogram of the same length.
 */
void DivergInfo::HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
                                std::vector<float> *expanded_histogram) {
  MS_ASSERT(quantized_histogram != nullptr && expanded_histogram != nullptr);
  MS_ASSERT(quant_bint_nums != 0);
  // Width of one target bin measured in source bins (may be fractional).
  const float bin_interval = static_cast<float>(bin_index) / static_cast<float>(quant_bint_nums);
  // merge i bins to target bins
  for (int i = 0; i < quant_bint_nums; ++i) {
    const float start = i * bin_interval;
    const float end = start + bin_interval;
    const int left_upper = static_cast<int>(std::ceil(start));
    if (left_upper > start) {
      // Partial overlap with the source bin to the left of `left_upper`.
      const double left_scale = left_upper - start;
      quantized_histogram->at(i) += left_scale * this->histogram[left_upper - 1];
    }
    const int right_lower = static_cast<int>(std::floor(end));
    if (right_lower < end) {
      // Partial overlap with the source bin at `right_lower`.
      const double right_scale = end - right_lower;
      quantized_histogram->at(i) += right_scale * this->histogram[right_lower];
    }
    // Fully covered source bins contribute their whole count.
    std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower,
                  [&quantized_histogram, i](float item) { quantized_histogram->at(i) += item; });
  }
  // expand target bins to i bins in order to calculate KL with reference_histogram
  for (int i = 0; i < quant_bint_nums; ++i) {
    const float start = i * bin_interval;
    const float end = start + bin_interval;
    // `count` is the effective number of non-empty source bins under this
    // target bin, with fractional weight for the edge bins.
    float count = 0;
    const int left_upper = static_cast<int>(std::ceil(start));
    float left_scale = 0.0f;
    if (left_upper > start) {
      left_scale = left_upper - start;
      if (this->histogram[left_upper - 1] != 0) {
        count += left_scale;
      }
    }
    const int right_lower = static_cast<int>(std::floor(end));
    double right_scale = 0.0f;
    if (right_lower < end) {
      right_scale = end - right_lower;
      if (this->histogram[right_lower] != 0) {
        count += right_scale;
      }
    }
    std::for_each(this->histogram.begin() + left_upper, this->histogram.begin() + right_lower, [&count](float item) {
      if (item != 0) {
        count += 1;
      }
    });
    // No non-empty source bins: nothing to redistribute.
    if (count == 0) {
      continue;
    }
    // Spread this target bin's mass uniformly over the non-empty source bins.
    const float average_num = quantized_histogram->at(i) / count;
    if (left_upper > start && this->histogram[left_upper - 1] != 0) {
      expanded_histogram->at(left_upper - 1) += average_num * left_scale;
    }
    if (right_lower < end && this->histogram[right_lower] != 0) {
      expanded_histogram->at(right_lower) += average_num * right_scale;
    }
    for (int k = left_upper; k < right_lower; ++k) {
      if (this->histogram[k] != 0) {
        expanded_histogram->at(k) += average_num;
      }
    }
  }
}
248
/**
 * Compute the clipping threshold best_T for activation quantization.
 * Three strategies, chosen by activation_quant_method:
 *  - MAX_MIN: best_T = max(|min|, |max|) (no clipping beyond observed range).
 *  - REMOVAL_OUTLIER: best_T from percentile bounds of the per-batch extrema.
 *  - otherwise (KL): search the histogram bin index (>= 128) whose clipped
 *    distribution has minimal KL divergence from the int8-quantized one.
 */
STATUS DivergInfo::ComputeThreshold() {
  if (activation_quant_method == MAX_MIN) {
    this->best_T = std::max(fabs(this->max), fabs(this->min));
    MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
    return RET_OK;
  }

  if (activation_quant_method == REMOVAL_OUTLIER && !this->min_datas.empty()) {
    this->percent_result = OutlierMethod(min_datas, max_datas);
    this->best_T = std::max(std::fabs(percent_result.first), std::fabs(percent_result.second));
    return RET_OK;
  }

  // KL method: candidate thresholds start at bin 128 (INT8_MAX + 1).
  int threshold = INT8_MAX + 1;
  float min_kl = FLT_MAX;
  // Mass beyond the current candidate threshold, folded into the last kept bin.
  float after_threshold_sum = std::accumulate(this->histogram.begin() + INT8_MAX + 1, this->histogram.end(), 0.0f);

  for (int i = INT8_MAX + 1; i < this->bin_num; ++i) {
    std::vector<float> quantized_histogram(INT8_MAX + 1, 0);
    std::vector<float> reference_histogram(this->histogram.begin(), this->histogram.begin() + i);
    std::vector<float> expanded_histogram(i, 0);
    // Clip: everything above bin i is accumulated into bin i-1.
    reference_histogram[i - 1] += after_threshold_sum;
    after_threshold_sum -= this->histogram[i];
    // handle bins for computing KL.
    HandleBinForKL(INT8_MAX + 1, i, &quantized_histogram, &expanded_histogram);
    // KL(p || q) with both histograms normalized to probability distributions.
    auto KLDivergence = [](std::vector<float> p, std::vector<float> q) {
      auto sum = 0.0f;
      std::for_each(p.begin(), p.end(), [&sum](float item) { sum += item; });
      std::for_each(p.begin(), p.end(), [sum](float &item) { item /= sum; });
      sum = 0.0f;
      std::for_each(q.begin(), q.end(), [&sum](float item) { sum += item; });
      std::for_each(q.begin(), q.end(), [sum](float &item) { item /= sum; });

      float result = 0.0f;
      const int size = p.size();
      for (int i = 0; i < size; ++i) {
        if (p[i] != 0) {
          if (q[i] == 0) {
            // q has no mass where p does: apply a fixed penalty.
            result += 1.0f;
          } else {
            result += (p[i] * std::log((p[i]) / (q[i])));
          }
        }
      }
      return result;
    };
    const float kl = KLDivergence(reference_histogram, expanded_histogram);
    if (kl < min_kl) {
      min_kl = kl;
      threshold = i;
    }
  }
  // Threshold value is the center of the winning bin.
  this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
  MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold << " T: " << best_T
                << " max: " << std::max(fabs(this->max), fabs(this->min));
  return RET_OK;
}
306
GetScale()307 std::pair<CNodePtr, float> DivergInfo::GetScale() {
308 float max_value = this->best_T;
309 float min_value = -max_value;
310
311 if (this->activation_quant_method == REMOVAL_OUTLIER) {
312 min_value = percent_result.first;
313 max_value = percent_result.second;
314 }
315
316 MS_CHECK_TRUE_MSG(quant_max - quant_min != 0, {}, "quant_max - quant_min == 0");
317 float scale = (max_value - min_value) / (quant_max - quant_min);
318 this->scale_tmp = scale;
319 MS_ASSERT(fabs(scale) <= 0.0f);
320 return std::make_pair(this->cnode, scale);
321 }
322
GetZeropoint()323 std::pair<CNodePtr, int32_t> DivergInfo::GetZeropoint() {
324 int zero_point = 0;
325 if (quant_min == 0 && quant_max == UINT8_MAX) {
326 zero_point = INT8_MAX + 1;
327 } else if (quant_min == INT_LEAST8_MIN + 1 && quant_max == INT8_MAX) {
328 zero_point = 0;
329 } else {
330 MS_LOG(WARNING) << "unexpected quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
331 }
332 if (this->activation_quant_method == REMOVAL_OUTLIER) {
333 MS_CHECK_TRUE_MSG(fabs(scale_tmp) <= 0.0f, {}, "fabs(scale_tmp) > 0.0f");
334 zero_point = std::round(quant_max - percent_result.second / scale_tmp);
335 }
336 return std::make_pair(this->cnode, zero_point);
337 }
338
GetInputDivergInfo()339 std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetInputDivergInfo() {
340 return &this->inputs_diverg_info_;
341 }
342
GetOutputDivergInfo()343 std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *Calibrator::GetOutputDivergInfo() {
344 return &this->outputs_diverg_info_;
345 }
346
RecordMaxMinValue(const vector<float> & data,const std::unique_ptr<DivergInfo> & diverg_info)347 STATUS Calibrator::RecordMaxMinValue(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
348 auto ret = diverg_info->RecordMaxMinValue(data);
349 if (ret != RET_OK) {
350 MS_LOG(ERROR) << "Record max min value failed.";
351 return ret;
352 }
353 ret = diverg_info->RecordMaxMinValueArray(data);
354 if (ret != RET_OK) {
355 MS_LOG(ERROR) << "Record max min value array failed.";
356 return ret;
357 }
358 return RET_OK;
359 }
360
/**
 * Compute clipping thresholds for all recorded DivergInfos.
 * Outputs are computed first; then each input reuses its producer's output
 * threshold when the producer's DivergInfo is found (avoids recomputation and
 * keeps producer/consumer quant params consistent), except through
 * TupleGetItem. Inputs with no matching producer compute their own threshold.
 */
STATUS Calibrator::ComputeThreshold() {
  for (auto &kv : this->outputs_diverg_info_) {
    auto &outputs_diverg_info = kv.second;
    for (auto &diverg_info : outputs_diverg_info) {
      auto ret = diverg_info->ComputeThreshold();
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Compute threshold failed.";
        return ret;
      }
    }
  }
  // node A's input may be node B's output, no need to re-compute the node A's input quant param which is the same as
  for (auto &kv : this->inputs_diverg_info_) {
    auto &input_infos = kv.second;
    for (size_t i = 0; i < input_infos.size(); i++) {
      auto cnode = input_infos[i]->cnode;
      bool already_computed = false;
      // i-th DivergInfo corresponds to input (i + 1) of the cnode (input 0 is the primitive).
      auto input = cnode->input(i + 1);
      if (input->isa<mindspore::CNode>()) {
        auto input_cnode = input->cast<CNodePtr>();
        // Search all output DivergInfos for the producer of this input.
        for (const auto &outputs_diverg_info : outputs_diverg_info_) {
          if (already_computed) {
            break;
          }
          for (const auto &output_diverg_info : outputs_diverg_info.second) {
            auto output_diverg_cnode = output_diverg_info->cnode;
            if (output_diverg_cnode == input_cnode) {
              if (NodePrimitiveType(input_cnode) != lite::kNameTupleGetItem) {
                // Copy the producer's computed info, but keep this consumer's cnode.
                *(input_infos[i]) = *output_diverg_info;
                input_infos[i]->cnode = cnode;
                already_computed = true;
                break;
              }
            }
          }
        }
      }
      if (!already_computed) {
        auto ret = input_infos[i]->ComputeThreshold();
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "ComputeThreshold failed.";
          return ret;
        }
      }
    }
  }
  return RET_OK;
}
409
UpdateDivergInterval(std::unordered_map<std::string,std::vector<std::unique_ptr<DivergInfo>>> * diverg_info)410 STATUS Calibrator::UpdateDivergInterval(
411 std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info) {
412 MS_ASSERT(diverg_info != nullptr);
413 for (auto &kv : *diverg_info) {
414 for (auto &info : kv.second) {
415 info->UpdateInterval();
416 }
417 }
418 return RET_OK;
419 }
420
UpdateDataFrequency(const vector<float> & data,const std::unique_ptr<DivergInfo> & diverg_info)421 STATUS Calibrator::UpdateDataFrequency(const vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info) {
422 MS_ASSERT(diverg_info != nullptr);
423 return diverg_info->UpdateHistogram(data);
424 }
425
AddQuantizedOp(const CNodePtr & cnode)426 STATUS Calibrator::AddQuantizedOp(const CNodePtr &cnode) {
427 if (cnode == nullptr) {
428 MS_LOG(ERROR) << "To be quantized cnode is null";
429 return RET_ERROR;
430 }
431 string node_name = cnode->fullname_with_scope();
432 std::unique_ptr<DivergInfo> input_diverg = std::make_unique<DivergInfo>(
433 cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
434 MS_CHECK_TRUE_MSG(input_diverg != nullptr, RET_NULL_PTR, "input_diverg is nullptr.");
435 std::unique_ptr<DivergInfo> output_diverg = std::make_unique<DivergInfo>(
436 cnode, kDefaultBinNumber, bit_num_, quant_max_, quant_min_, full_quant_param_.activation_quant_method);
437 MS_CHECK_TRUE_MSG(output_diverg != nullptr, RET_NULL_PTR, "output_diverg is nullptr.");
438 inputs_diverg_info_[node_name].push_back(std::move(input_diverg));
439 outputs_diverg_info_[node_name].push_back(std::move(output_diverg));
440 return RET_OK;
441 }
442
GenerateInputData(const std::string & input_name,size_t image_index,mindspore::tensor::MSTensor * tensor) const443 STATUS Calibrator::GenerateInputData(const std::string &input_name, size_t image_index,
444 mindspore::tensor::MSTensor *tensor) const {
445 return preprocess::PreProcess(data_pre_process_param_, input_name, image_index, tensor);
446 }
447
FullQuantQuantizer(FuncGraphPtr graph,int bit_num,TypeId target_type,bool per_channel)448 FullQuantQuantizer::FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type, bool per_channel)
449 : Quantizer(std::move(graph)) {
450 MS_ASSERT(graph != nullptr);
451 this->per_channel_ = per_channel;
452 this->bit_num = bit_num;
453 this->target_type_ = target_type;
454 if (target_type == kNumberTypeInt8) {
455 quant_max = (1 << (this->bit_num - 1)) - 1; // 127
456 quant_min = -quant_max; // -127
457 } else if (target_type == kNumberTypeUInt8) {
458 quant_max = (1 << this->bit_num) - 1; // 255
459 quant_min = 0;
460 } else {
461 MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
462 }
463 calibrator_ = std::make_unique<Calibrator>(this->bit_num, quant_max, quant_min);
464 if (calibrator_ == nullptr) {
465 MS_LOG(ERROR) << "create calibrator failed!";
466 return;
467 }
468 }
469
// Release the inference resources. Each session is deleted before its model —
// presumably the session references the model it was built from, so keep this
// order (confirm against the session API before reordering).
FullQuantQuantizer::~FullQuantQuantizer() {
  delete fp32_session_;
  delete fp32_model_;
  delete int8_session_;
  delete int8_model_;
}
476
SetInOutQuantParam(const AnfNodePtr & input_node,const std::unique_ptr<DivergInfo> & info,const PrimitivePtr & primitive,bool is_input,size_t index) const477 STATUS FullQuantQuantizer::SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DivergInfo> &info,
478 const PrimitivePtr &primitive, bool is_input, size_t index) const {
479 auto quant_param_holder = GetCNodeQuantHolder(primitive);
480 MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
481 schema::QuantParamT quant_param;
482 TypeId type_id = kTypeUnknown;
483 if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
484 MS_LOG(ERROR) << "Get data type failed.";
485 return RET_ERROR;
486 }
487 if (type_id == kNumberTypeFloat32 && info != nullptr) {
488 auto scale = info->GetScale().second;
489 if (scale == 0) {
490 MS_LOG(WARNING) << "The input or output values are very close to 0, so set the scale to 1.";
491 quant_param.scale = 1;
492 } else {
493 quant_param.scale = scale;
494 }
495 quant_param.zeroPoint = info->GetZeropoint().second;
496 quant_param.max = info->max;
497 quant_param.min = info->min;
498 quant_param.numBits = bit_num;
499 quant_param.narrowRange = false;
500 quant_param.inited = true;
501 quant_param.roundType = 1;
502 quant_param.multiplier = 1;
503 } else {
504 quant_param.inited = false;
505 }
506 std::vector<schema::QuantParamT> quant_params = {quant_param};
507 if (is_input) {
508 quant_param_holder->set_input_quant_param(index, quant_params);
509 } else {
510 quant_param_holder->set_output_quant_param(index, quant_params);
511 }
512 return RET_OK;
513 }
514
DoWeightQuant(const std::string & op_name,const AnfNodePtr & weight,const PrimitivePtr & primitive,bool per_channel,int input_index) const515 STATUS FullQuantQuantizer::DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight,
516 const PrimitivePtr &primitive, bool per_channel, int input_index) const {
517 MS_ASSERT(weight != nullptr);
518 MS_ASSERT(primitive != nullptr);
519 // perlayer
520 if (!weight->isa<Parameter>()) {
521 MS_LOG(ERROR) << "not a parameter";
522 return RET_PARAM_INVALID;
523 }
524 auto parameter = weight->cast<ParameterPtr>();
525 if (parameter == nullptr) {
526 MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter";
527 return RET_NULL_PTR;
528 }
529 auto tensor_info = parameter->default_param()->cast<tensor::TensorPtr>();
530 if (tensor_info == nullptr) {
531 MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
532 return RET_NULL_PTR;
533 }
534 auto bit_num_t = bit_num;
535 auto quant_max_t = quant_max;
536 auto quant_min_t = quant_min;
537 auto weight_quant_type = per_channel ? WeightQuantType::FIXED_BIT_PER_CHANNEL : WeightQuantType::FIXED_BIT_PER_LAYER;
538 auto status = FixedBitQuantFilter<int8_t>(tensor_info, primitive, QuantType_QUANT_ALL, quant_max_t, quant_min_t,
539 bit_num_t, weight_quant_type, kNumberTypeInt8, input_index - 1);
540 if (status != RET_OK) {
541 MS_LOG(ERROR) << "QuantFilter failed: " << status;
542 return status;
543 }
544 // set dtype
545 auto abstractBase = parameter->abstract();
546 if (abstractBase == nullptr) {
547 MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << parameter->name();
548 return RET_NULL_PTR;
549 }
550 if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
551 MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << parameter->name();
552 return RET_ERROR;
553 }
554 auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
555 if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
556 MS_LOG(ERROR) << "abstractTensor is nullptr, " << parameter->name();
557 return RET_NULL_PTR;
558 }
559 abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt8));
560 return RET_OK;
561 }
562
DoBiasQuant(const AnfNodePtr & bias,const PrimitivePtr & primitive)563 STATUS FullQuantQuantizer::DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive) {
564 if (primitive == nullptr || bias == nullptr) {
565 MS_LOG(ERROR) << "null pointer!";
566 return RET_NULL_PTR;
567 }
568 auto bias_parameter_ptr = bias->cast<ParameterPtr>();
569 MS_ASSERT(bias_parameter_ptr != nullptr);
570 auto bias_default_param = bias_parameter_ptr->default_param();
571 auto bias_param = bias_default_param->cast<tensor::TensorPtr>();
572 MS_ASSERT(bias_parameter != nullptr);
573 auto quant_param_holder = GetCNodeQuantHolder(primitive);
574 MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
575 auto active_weight_quant_params = quant_param_holder->get_input_quant_params();
576
577 auto active_params = active_weight_quant_params.at(FIRST_INPUT);
578 auto weight_params = active_weight_quant_params.at(SECOND_INPUT);
579
580 vector<double> input_scales;
581 vector<double> filter_scales;
582 vector<double> bias_scales;
583 size_t sizeX = active_params.size();
584 for (size_t i = 0; i < sizeX; i++) {
585 input_scales.emplace_back(active_params[i].scale);
586 }
587 size_t sizeY = weight_params.size();
588 if (sizeX != sizeY) {
589 if (sizeX > 1 && sizeY > 1) {
590 MS_LOG(ERROR) << "input and filter's scale count cannot match!";
591 return RET_ERROR;
592 }
593 }
594 for (size_t i = 0; i < sizeY; i++) {
595 filter_scales.emplace_back(weight_params[i].scale);
596 }
597 size_t size = std::max(sizeX, sizeY);
598 for (size_t i = 0; i < size; i++) {
599 auto scaleX = sizeX > 1 ? input_scales[i] : input_scales[0];
600 auto scaleY = sizeY > 1 ? filter_scales[i] : filter_scales[0];
601 bias_scales.push_back(scaleX * scaleY);
602 }
603 MS_ASSERT(!bias_scales.empty());
604 size_t shape_size = bias_param->DataSize();
605
606 // set bias quant param
607 std::vector<schema::QuantParamT> quant_params;
608 for (double bias_scale : bias_scales) {
609 schema::QuantParamT quant_param;
610 if (bias_scale == 0) {
611 MS_LOG(WARNING) << "bias_scale == 0";
612 quant_param.scale = 1;
613 } else {
614 quant_param.scale = bias_scale;
615 }
616 quant_param.zeroPoint = 0;
617 quant_param.inited = true;
618 quant_params.emplace_back(quant_param);
619 }
620 // quant bias data
621 std::vector<int32_t> quant_datas(shape_size);
622
623 auto *raw_datas = static_cast<float *>(bias_param->data_c());
624 if (ComputeBiasDataAndQuantParam(bias_scales, input_scales, raw_datas, quant_param_holder, &quant_params,
625 &quant_datas) != RET_OK) {
626 MS_LOG(ERROR) << "compute bias data failed.";
627 return RET_ERROR;
628 }
629 quant_param_holder->set_input_quant_param(THIRD_INPUT, quant_params);
630 auto ret = SetTensorData(bias_param, quant_datas.data(), shape_size * sizeof(int32_t));
631 if (ret != RET_OK) {
632 MS_LOG(ERROR) << "set tensor data failed.";
633 return RET_ERROR;
634 }
635 // set dtype
636 auto abstractBase = bias_parameter_ptr->abstract();
637 if (abstractBase == nullptr) {
638 MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << bias_parameter_ptr->name();
639 return RET_ERROR;
640 }
641 if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
642 MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << bias_parameter_ptr->name();
643 return RET_ERROR;
644 }
645 auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
646 if (abstractTensor == nullptr || abstractTensor->element() == nullptr) {
647 MS_LOG(ERROR) << "abstractTensor is nullptr" << bias_parameter_ptr->name();
648 return RET_NULL_PTR;
649 }
650 abstractTensor->element()->set_type(TypeIdToType(kNumberTypeInt32));
651 return RET_OK;
652 }
653
DoParameterNodeQuant(const CNodePtr & cnode,const AnfNodePtr & input_node,size_t input_index)654 STATUS FullQuantQuantizer::DoParameterNodeQuant(const CNodePtr &cnode, const AnfNodePtr &input_node,
655 size_t input_index) {
656 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
657 if (primitive == nullptr) {
658 return RET_ERROR;
659 }
660 auto op_name = cnode->fullname_with_scope();
661 STATUS ret;
662 TypeId type_id = kTypeUnknown;
663 if (opt::GetDataTypeFromAnfNode(input_node, &type_id) != RET_OK) {
664 MS_LOG(ERROR) << "Get data type failed.";
665 return RET_ERROR;
666 }
667 // support for share weight.
668 if (type_id == kNumberTypeInt8) {
669 return RET_CONTINUE;
670 }
671 if (type_id != kNumberTypeFloat32) {
672 ret = SetInOutQuantParam(input_node, nullptr, primitive, true, input_index - 1);
673 if (ret != RET_OK) {
674 MS_LOG(ERROR) << "Set In/Out quant param failed.";
675 return ret;
676 }
677 return RET_QUANT_CONTINUE;
678 }
679 if (CheckNodeInSet(cnode, has_bias_operator)) {
680 if (input_index == FOURTH_INPUT) {
681 ret = DoBiasQuant(input_node, primitive);
682 if (ret != RET_OK) {
683 MS_LOG(ERROR) << "Do bias quant failed.";
684 return ret;
685 }
686 } else {
687 if (opt::CheckPrimitiveType(cnode, prim::kPrimMatMul)) {
688 ret = DoWeightQuant(op_name, input_node, primitive, false, input_index);
689 } else {
690 ret = DoWeightQuant(op_name, input_node, primitive, true, input_index);
691 }
692 if (ret != RET_OK) {
693 MS_LOG(ERROR) << "Do bias quant failed.";
694 return ret;
695 }
696 }
697 } else {
698 ret = DoWeightQuant(op_name, input_node, primitive, false, input_index);
699 if (ret != RET_OK) {
700 MS_LOG(ERROR) << "Do bias quant failed.";
701 return ret;
702 }
703 }
704 return RET_OK;
705 }
706
/**
 * Set quant params for every input of a simple (non-TupleGetItem) cnode.
 * Graph inputs and un-quantized CNode inputs consume one DivergInfo each (in
 * order, via activation_input_index); CNode inputs whose producer already has
 * output quant params reuse those; Parameter inputs go through
 * DoParameterNodeQuant.
 */
STATUS FullQuantQuantizer::QuantNodeSimpleOp(const CNodePtr &cnode) {
  MS_ASSERT(cnode != nullptr);
  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
  auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (primitive == nullptr) {
    return RET_ERROR;
  }
  auto op_name = cnode->fullname_with_scope();
  auto primitive_quant_holder = GetCNodeQuantHolder(primitive);
  MS_CHECK_TRUE_MSG(primitive_quant_holder != nullptr, RET_NULL_PTR, "primitive_quant_holder is nullptr.");
  // Tracks which recorded DivergInfo corresponds to the next activation input.
  size_t activation_input_index = 0;
  STATUS ret;
  for (size_t i = 1; i < cnode->inputs().size(); i++) {
    auto input_node = cnode->input(i);
    MS_ASSERT(input_node != nullptr);
    // A Parameter without a default value is a graph input (activation), not a weight.
    bool is_graph_input = false;
    if (input_node->isa<Parameter>()) {
      if (!input_node->cast<ParameterPtr>()->has_default()) {
        is_graph_input = true;
      }
    }
    if (is_graph_input) {
      // do input quant
      auto &info = (*inputs_diverg_info)[op_name][activation_input_index++];
      ret = SetInOutQuantParam(input_node, info, primitive, true, i - 1);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Set activation quant failed.";
        return ret;
      }
    } else if (input_node->isa<mindspore::CNode>()) {
      auto input_cnode = input_node->cast<mindspore::CNodePtr>();
      auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
      if (input_cnode_primitive == nullptr) {
        MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
                      << " Primitive is null";
        continue;
      }
      auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive);
      MS_CHECK_TRUE_MSG(input_primitive_quant_holder != nullptr, RET_NULL_PTR,
                        "input_primitive_quant_holder is nullptr.");
      if (input_primitive_quant_holder->IsOutputQuantParamsInited()) {
        // Producer already quantized: reuse its output quant param for this input.
        auto quant_param = input_primitive_quant_holder->get_output_quant_params().front();
        primitive_quant_holder->set_input_quant_param(i - 1, quant_param);
        // Still consumes a DivergInfo slot to keep indices aligned.
        activation_input_index++;
      } else {
        // do input quant
        auto &info = (*inputs_diverg_info)[op_name][activation_input_index++];
        ret = SetInOutQuantParam(input_node, info, primitive, true, i - 1);
        if (ret != RET_OK) {
          MS_LOG(ERROR) << "Set activation quant failed.";
          return ret;
        }
      }
    } else if (input_node->isa<mindspore::Parameter>()) {
      // Weight/bias parameter with a default value.
      ret = DoParameterNodeQuant(cnode, input_node, i);
      if (ret == RET_QUANT_CONTINUE) {
        continue;
      } else if (ret != RET_OK) {
        MS_LOG(ERROR) << "Do parameter node quant failed.";
        return ret;
      }
    } else {
      MS_LOG(ERROR) << input_node->fullname_with_scope() << ":" << input_node->type_name() << " is not support type";
      return RET_ERROR;
    }
  }
  return RET_OK;
}
775
// Attach the calibrated quantization parameters to every recorded node.
// TupleGetItem nodes simply forward the quant params of the selected output of
// their producer; every other node goes through QuantNodeSimpleOp() for its
// inputs/weights and SetInOutQuantParam() for each of its outputs.
STATUS FullQuantQuantizer::QuantNode() {
  auto inputs_diverg_info = calibrator_->GetInputDivergInfo();
  auto outputs_diverg_info = calibrator_->GetOutputDivergInfo();

  auto cnodes = funcGraph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    auto op_name = cnode->fullname_with_scope();
    auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (primitive == nullptr) {
      MS_LOG(ERROR) << "primitive is nullptr";
      continue;
    }
    auto primitive_quant_holder = GetCNodeQuantHolder(primitive);
    MS_CHECK_TRUE_MSG(primitive_quant_holder != nullptr, RET_NULL_PTR, "primitive_quant_holder is nullptr.");
    // Nodes without calibration statistics were never selected for
    // quantization; tag them explicitly so later passes skip them.
    if (inputs_diverg_info->find(op_name) == inputs_diverg_info->end()) {
      MS_LOG(INFO) << op_name << " can not do quant";
      primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_NONE);
      continue;
    }

    auto op_type = primitive->name();
    MS_LOG(DEBUG) << "OpName: " << op_name;
    if (op_type == lite::kNameTupleGetItem) {
      // TupleGetItem(input, index): copy the quant param of the index-th
      // output of `input` onto this node's input 0 and output 0.
      constexpr int tuple_get_item_input_size = 3;
      MS_CHECK_TRUE_MSG(cnode->size() == tuple_get_item_input_size, RET_ERROR, "cnode->size() != 3");
      auto index_node = cnode->input(THIRD_INPUT);
      auto index_value_node = index_node->cast<mindspore::ValueNodePtr>();
      if (index_value_node == nullptr) {
        MS_LOG(WARNING) << "index value node is null";
        continue;
      }
      // NOTE(review): assumes CastToInt() returns a non-empty vector here —
      // .front() on an empty result would be UB; confirm upstream guarantees.
      size_t index = opt::CastToInt(index_value_node->value()).front();
      auto input_node = cnode->input(SECOND_INPUT);
      MS_CHECK_TRUE_MSG(input_node != nullptr, RET_ERROR, "input_node == nullptr");
      auto input_cnode = input_node->cast<mindspore::CNodePtr>();
      MS_CHECK_TRUE_MSG(input_cnode != nullptr, RET_ERROR, "input_cnode == nullptr");
      auto input_cnode_primitive = GetValueNode<PrimitivePtr>(input_cnode->input(0));
      if (input_cnode_primitive == nullptr) {
        MS_LOG(WARNING) << "input_cnode_primitive is null";
        continue;
      }
      auto input_primitive_quant_holder = GetCNodeQuantHolder(input_cnode_primitive);
      MS_CHECK_TRUE_MSG(input_primitive_quant_holder != nullptr, RET_NULL_PTR,
                        "input_primitive_quant_holder is nullptr.");

      if (input_primitive_quant_holder->get_output_quant_params().size() > index) {
        auto quant_param = input_primitive_quant_holder->get_output_quant_params()[index];
        primitive_quant_holder->set_input_quant_param(0, quant_param);
        primitive_quant_holder->set_output_quant_param(0, quant_param);
      } else {
        // Producer has fewer quantized outputs than the requested index;
        // leave the params unset but still mark the node as quantized.
        MS_LOG(WARNING) << "this TupleGetItem node's input node: " << input_cnode->fullname_with_scope()
                        << "'s output quant_params size: "
                        << input_primitive_quant_holder->get_output_quant_params().size() << ", but index: " << index;
      }
      primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
      continue;
    } else {  // do simple op quant
      auto status = QuantNodeSimpleOp(cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "simple op quant failed.";
        return status;
      }
    }
    // do output quant, there may multi-output
    auto &infos = (*outputs_diverg_info)[op_name];
    for (size_t index = 0; index < infos.size(); index++) {
      auto &info = infos.at(index);
      auto ret = SetInOutQuantParam(cnode, info, primitive, false, index);
      if (ret != RET_OK) {
        MS_LOG(ERROR) << "Set In/Out quant param failed.";
        return ret;
      }
      primitive_quant_holder->set_quant_type(schema::QuantType_QUANT_ALL);
    }
  }
  return RET_OK;
}
853
UpdateDivergeInterval()854 STATUS FullQuantQuantizer::UpdateDivergeInterval() {
855 auto ret = this->calibrator_->UpdateDivergInterval(this->calibrator_->GetInputDivergInfo());
856 if (ret != RET_OK) {
857 MS_LOG(ERROR) << "Update input diverge interval failed.";
858 return ret;
859 }
860 ret = this->calibrator_->UpdateDivergInterval(this->calibrator_->GetOutputDivergInfo());
861 if (ret != RET_OK) {
862 MS_LOG(ERROR) << "Update output diverge interval failed.";
863 return ret;
864 }
865 return RET_OK;
866 }
867
868 /**
869 * Mark quantifiable nodes
870 **/
PreProcess()871 STATUS FullQuantQuantizer::PreProcess() {
872 auto cnodes = funcGraph->GetOrderedCnodes();
873 for (auto &cnode : cnodes) {
874 AnfNodePtr anf = cnode->cast<AnfNodePtr>();
875 if (anf == nullptr) {
876 MS_LOG(ERROR) << " cnode is null";
877 return RET_NULL_PTR;
878 }
879 if (mindspore::lite::quant::QuantStrategy::CanOpFullQuantized(anf)) {
880 auto ret = calibrator_->AddQuantizedOp(cnode);
881 if (ret != RET_OK) {
882 MS_LOG(ERROR) << "Add Quantized Op failed.";
883 return ret;
884 }
885 }
886 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
887 if (primitive == nullptr) {
888 MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is null";
889 continue;
890 }
891 auto quant_param_holder = GetCNodeQuantHolder(primitive);
892 MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
893 quant_param_holder->ClearInputOutputQuantParam();
894 }
895 return RET_OK;
896 }
897
CheckFp32TensorVec(const std::string & node_name,const std::vector<mindspore::tensor::MSTensor * > & tensor_vec)898 STATUS FullQuantQuantizer::CheckFp32TensorVec(const std::string &node_name,
899 const std::vector<mindspore::tensor::MSTensor *> &tensor_vec) {
900 MS_ASSERT(tensor_vec != nullptr);
901 if (tensor_vec.empty()) {
902 MS_LOG(ERROR) << "node: " << node_name << " input tensors is 0";
903 return RET_ERROR;
904 }
905 auto *tensor = tensor_vec[0];
906 MS_ASSERT(tensor != nullptr);
907 if (tensor->data_type() != kNumberTypeFloat32) {
908 MS_LOG(INFO) << "node: " << node_name << " will not quantize"
909 << " tensor data_type: " << tensor->data_type();
910 return RET_ERROR;
911 }
912 return RET_OK;
913 }
914
915 /**
916 * 1. create input tensor
917 * 2. insert callback to session
918 * 3. run session
919 **/
DoInference()920 STATUS FullQuantQuantizer::DoInference() {
921 // get input tensor
922 vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
923 if (inputs.size() != calibrator_->GetInputNum()) {
924 MS_LOG(ERROR) << "model's input tensor count: " << inputs.size() << " != "
925 << " calibrator count:" << calibrator_->GetInputNum();
926 return RET_ERROR;
927 }
928
929 for (size_t calib_index = 0; calib_index < calibrator_->GetBatchNum(); calib_index++) {
930 // set multi-input data
931 for (size_t input_index = 0; input_index < inputs.size(); input_index++) {
932 STATUS status =
933 calibrator_->GenerateInputData(inputs[input_index]->tensor_name(), calib_index, inputs[input_index]);
934 if (status != RET_OK) {
935 MS_LOG(ERROR) << "generate input data from images failed!";
936 return RET_ERROR;
937 }
938 }
939
940 KernelCallBack beforeCallBack = [&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
941 const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
942 const CallBackParam &callParam) -> bool {
943 auto diverg_info_map = calibrator_->GetInputDivergInfo();
944 if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
945 return true;
946 }
947 if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, beforeOutputs) != RET_OK) {
948 return true;
949 }
950 bool is_init = beforeInputs.size() > 1 && (*diverg_info_map)[callParam.node_name].size() == 1;
951 if (is_init) {
952 for (size_t i = 1; i < beforeInputs.size(); i++) {
953 if (beforeInputs.at(i)->data_type() != kNumberTypeFloat32 || beforeInputs.at(i)->IsConst()) {
954 continue;
955 }
956 auto input_diverg = std::make_unique<DivergInfo>();
957 MS_CHECK_TRUE_MSG(input_diverg != nullptr, false, "input_diverg is nullptr.");
958 *input_diverg = *((*diverg_info_map)[callParam.node_name][0]);
959 (*diverg_info_map)[callParam.node_name].push_back(std::move(input_diverg));
960 }
961 }
962 for (size_t i = 0; i < (*diverg_info_map)[callParam.node_name].size(); i++) {
963 auto tensor = beforeInputs[i];
964 MS_ASSERT(tensor != nullptr);
965 const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
966 MS_CHECK_TRUE_MSG(tensor_data != nullptr, false, "tensor_data is nullptr.");
967 size_t elem_count = tensor->ElementsNum();
968 vector<float> data(tensor_data, tensor_data + elem_count);
969 auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[callParam.node_name][i]);
970 MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record MaxMinValue failed!");
971 }
972 return true;
973 };
974 // func
975 KernelCallBack afterCallBack = [&](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
976 const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
977 const CallBackParam &callParam) -> bool {
978 auto diverg_info_map = calibrator_->GetOutputDivergInfo();
979 if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
980 return true;
981 }
982 if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
983 return true;
984 }
985 bool is_init = afterOutputs.size() > 1 && (*diverg_info_map)[callParam.node_name].size() == 1;
986 if (is_init) {
987 for (size_t i = 1; i < afterOutputs.size(); i++) {
988 auto output_diverg = std::make_unique<DivergInfo>();
989 CHECK_NULL_RETURN(output_diverg);
990 *output_diverg = *((*diverg_info_map)[callParam.node_name][0]);
991 (*diverg_info_map)[callParam.node_name].push_back(std::move(output_diverg));
992 }
993 }
994 size_t output_i = 0;
995 for (const auto &tensor : afterOutputs) {
996 const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
997 CHECK_NULL_RETURN(tensor_data);
998 size_t elem_count = tensor->ElementsNum();
999 vector<float> data(tensor_data, tensor_data + elem_count);
1000 auto ret = this->calibrator_->RecordMaxMinValue(data, (*diverg_info_map)[callParam.node_name][output_i]);
1001 MS_CHECK_TRUE_MSG(ret == RET_OK, false, "Record MaxMinValue failed!");
1002 output_i++;
1003 }
1004 return true;
1005 };
1006 auto status = fp32_session_->RunGraph(beforeCallBack, afterCallBack);
1007 if (status != RET_OK) {
1008 MS_LOG(ERROR) << "run model failed!";
1009 return RET_ERROR;
1010 }
1011 }
1012 return RET_OK;
1013 }
1014
Int8Inference()1015 STATUS FullQuantQuantizer::Int8Inference() {
1016 // int8 inference
1017 vector<mindspore::tensor::MSTensor *> inputs = int8_session_->GetInputs();
1018 for (auto input_tensor : inputs) {
1019 // get input tensor
1020 auto elem_count = input_tensor->ElementsNum();
1021 vector<float> dummy_data(elem_count);
1022 // set the input data to 0.1
1023 std::fill(dummy_data.begin(), dummy_data.end(), 0.1);
1024 auto ret =
1025 memcpy_s(input_tensor->MutableData(), input_tensor->Size(), dummy_data.data(), sizeof(float) * dummy_data.size());
1026 if (ret != EOK) {
1027 MS_LOG(ERROR) << "memcpy_s error: " << ret;
1028 return RET_ERROR;
1029 }
1030 }
1031
1032 for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
1033 // before func
1034 KernelCallBack before_call_back = GetBeforeCallBack(true);
1035 // after func
1036 KernelCallBack after_call_back = GetAfterCallBack(true);
1037 auto ret = int8_session_->RunGraph(before_call_back, after_call_back);
1038 if (ret != RET_OK) {
1039 MS_LOG(ERROR) << "run model failed!";
1040 return RET_ERROR;
1041 }
1042 } // end for images
1043 return RET_OK;
1044 }
1045
// Bias-correction driver: run the int8 model on a worker thread and the fp32
// model on this thread over the calibration set.  The two passes exchange
// per-op activations/channel means through the Op*DataHandle buffers; the int8
// after-callback accumulates per-channel output differences into
// op_bias_diff_map, which is then averaged and folded into each node's bias.
STATUS FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph) {
  std::future<STATUS> int8_inference = std::async(std::launch::async, &FullQuantQuantizer::Int8Inference, this);
  // get input tensor
  vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
  if (inputs.size() != 1) {
    // NOTE(review): returning before int8_inference.get() abandons the worker,
    // whose callbacks busy-wait for fp32 data — the future's destructor then
    // blocks, so this path may hang; confirm intended behavior.
    MS_LOG(ERROR) << "model's input tensor size: " << inputs.size();
    return RET_ERROR;
  }
  // fp32 inference
  for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
    for (size_t input_index = 0; input_index < inputs.size(); input_index++) {
      STATUS status = calibrator_->GenerateInputData(inputs[input_index]->tensor_name(), i, inputs[input_index]);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "generate input data from images failed!";
        return RET_ERROR;
      }
    }
    // before func: store conv input activations for the int8 pass
    KernelCallBack before_call_back = GetBeforeCallBack(false);
    // after func: store per-channel means of conv outputs
    KernelCallBack after_call_back = GetAfterCallBack(false);
    auto status = fp32_session_->RunGraph(before_call_back, after_call_back);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "run model failed!";
      return RET_ERROR;
    }
  }  // end for images

  // Wait for the int8 worker to finish before touching op_bias_diff_map.
  STATUS status = int8_inference.get();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "int8 inference failed!";
    return RET_ERROR;
  }
  if (calibrator_->GetBatchNum() == 0) {
    MS_LOG(ERROR) << "divisor 'calibrate_size' cannot be 0.";
    return RET_ERROR;
  }
  // Average the accumulated per-channel differences over all batches.
  for (auto &key_value : op_bias_diff_map) {
    std::for_each(key_value.second.begin(), key_value.second.end(),
                  [this](float &data) { data = data / calibrator_->GetBatchNum(); });
  }
  auto cnodes = func_graph->GetOrderedCnodes();
  for (auto &cnode : cnodes) {
    auto op_name = cnode->fullname_with_scope();
    if (op_bias_diff_map.find(op_name) == op_bias_diff_map.end()) {
      continue;
    }
    status = BiasCorrection(func_graph, cnode);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "do node bias correct failed.";
      break;
    }
  }
  return status;
}
1101
BiasCorrection(const FuncGraphPtr & func_graph,const CNodePtr & cnode)1102 STATUS FullQuantQuantizer::BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
1103 auto op_name = cnode->fullname_with_scope();
1104 const auto &bias_diff = op_bias_diff_map[op_name];
1105 auto primitive = GetValueNode<PrimitivePtr>(cnode->input(0));
1106 if (primitive == nullptr) {
1107 MS_LOG(ERROR) << "primitive is nullptr";
1108 return RET_NULL_PTR;
1109 }
1110 auto quant_param_holder = GetCNodeQuantHolder(primitive);
1111 MS_CHECK_TRUE_MSG(quant_param_holder != nullptr, RET_NULL_PTR, "quant_param_holder is nullptr.");
1112 auto input_quant_params = quant_param_holder->get_input_quant_params();
1113 if (input_quant_params.size() == DIMENSION_3D) {
1114 // compensate the existed
1115 auto bias_quant_params = input_quant_params.at(THIRD_INPUT);
1116 auto bias = cnode->input(FOURTH_INPUT);
1117 auto bias_parameter_ptr = bias->cast<ParameterPtr>();
1118 auto bias_default_param = bias_parameter_ptr->default_param();
1119 auto bias_param = bias_default_param->cast<tensor::TensorPtr>();
1120 int *bias_datas = static_cast<int *>(bias_param->data_c());
1121
1122 if (static_cast<size_t>(bias_param->DataSize()) != bias_diff.size()) {
1123 MS_LOG(DEBUG) << "unexpected bias data count: " << bias_param->DataSize()
1124 << " not the same as bias_diff: " << bias_diff.size();
1125 return RET_ERROR;
1126 }
1127 if (bias_quant_params.size() != bias_diff.size()) {
1128 MS_LOG(ERROR) << "unexpected bias quant params size: " << bias_quant_params.size()
1129 << " not the same as bias_diff: " << bias_diff.size();
1130 return RET_ERROR;
1131 }
1132 for (int i = 0; i < bias_param->DataSize(); i++) {
1133 auto scale = bias_quant_params[i].scale;
1134 if (fabs(scale) <= 0.0f) {
1135 MS_LOG(ERROR) << "divisor 'scale' cannot be 0.";
1136 return RET_ERROR;
1137 }
1138 double after_correct = std::round(bias_diff[i] / scale) + bias_datas[i];
1139 const constexpr int32_t corrected_bias_abs_limit = 0.6 * INT32_MAX;
1140 if (after_correct > corrected_bias_abs_limit) {
1141 MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too large: " << after_correct
1142 << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i] << " scale: " << scale;
1143 bias_datas[i] = static_cast<int>(corrected_bias_abs_limit);
1144 } else if (after_correct < -corrected_bias_abs_limit) {
1145 MS_LOG(WARNING) << op_name << " ch: " << i << " bias after_corrected too small: " << after_correct
1146 << " origin value: " << bias_datas[i] << " bias_diff: " << bias_diff[i] << " scale: " << scale;
1147 bias_datas[i] = static_cast<int>(-corrected_bias_abs_limit);
1148 } else {
1149 auto diff = static_cast<int>(std::round(bias_diff[i] / scale));
1150 bias_datas[i] += diff;
1151 }
1152 }
1153 } else if (input_quant_params.size() == DIMENSION_2D) {
1154 MS_LOG(INFO) << op_name << " add bias input";
1155 // need to add bias input
1156 auto parameter = func_graph->add_parameter();
1157 if (parameter == nullptr) {
1158 MS_LOG(ERROR) << "parameter is nullptr.";
1159 return RET_NULL_PTR;
1160 }
1161 ShapeVector shape;
1162 shape.push_back(bias_diff.size());
1163
1164 auto tensor_info = CreateTensorInfo(bias_diff.data(), sizeof(float) * bias_diff.size(), shape, kNumberTypeFloat32);
1165 if (tensor_info == nullptr) {
1166 MS_LOG(ERROR) << "create tensor info failed.";
1167 return RET_ERROR;
1168 }
1169 auto status = InitParameterFromTensorInfo(parameter, tensor_info);
1170 if (status != RET_OK) {
1171 MS_LOG(ERROR) << "init parameter from tensor info failed";
1172 return RET_ERROR;
1173 }
1174 parameter->set_name("added_" + op_name + "_bias");
1175 cnode->add_input(parameter);
1176 status = DoBiasQuant(parameter, primitive);
1177 if (status != RET_OK) {
1178 MS_LOG(ERROR) << "Do bias quant failed.";
1179 return RET_ERROR;
1180 }
1181 } else {
1182 MS_LOG(ERROR) << "unexpected get_input_quant_params size: " << input_quant_params.size();
1183 }
1184 return RET_OK;
1185 }
1186
CollectDataFrequency()1187 STATUS FullQuantQuantizer::CollectDataFrequency() {
1188 // get input tensor
1189 vector<mindspore::tensor::MSTensor *> inputs = fp32_session_->GetInputs();
1190 if (inputs.size() != calibrator_->GetInputNum()) {
1191 MS_LOG(ERROR) << "model's input tensor cnt: " << inputs.size() << " != " << calibrator_->GetInputNum();
1192 return RET_ERROR;
1193 }
1194
1195 for (size_t i = 0; i < calibrator_->GetBatchNum(); i++) {
1196 // set multi-input data
1197 for (size_t input_index = 0; input_index < inputs.size(); input_index++) {
1198 STATUS status = calibrator_->GenerateInputData(inputs[input_index]->tensor_name(), i, inputs[input_index]);
1199 if (status != RET_OK) {
1200 MS_LOG(ERROR) << "generate input data from images failed!";
1201 return RET_ERROR;
1202 }
1203 }
1204
1205 KernelCallBack before_callback = [&](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
1206 const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
1207 const CallBackParam &callParam) {
1208 auto diverg_info_map = calibrator_->GetInputDivergInfo();
1209 if (diverg_info_map->find(callParam.node_name) == diverg_info_map->end()) {
1210 return true;
1211 }
1212 if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, before_inputs) != RET_OK) {
1213 return true;
1214 }
1215 int input_i = 0;
1216 for (auto tensor : before_inputs) {
1217 if (tensor->data_type() != kNumberTypeFloat32 || tensor->IsConst()) {
1218 continue;
1219 }
1220 const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
1221 MS_ASSERT(tensor_data != nullptr);
1222 size_t elem_count = tensor->ElementsNum();
1223 vector<float> data(tensor_data, tensor_data + elem_count);
1224 auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[callParam.node_name][input_i++]);
1225 if (ret != RET_OK) {
1226 return false;
1227 }
1228 }
1229 return true;
1230 };
1231
1232 KernelCallBack after_callBack = [&](const std::vector<mindspore::tensor::MSTensor *> &after_inputs,
1233 const std::vector<mindspore::tensor::MSTensor *> &after_outputs,
1234 const CallBackParam &call_param) {
1235 auto diverg_info_map = calibrator_->GetOutputDivergInfo();
1236 if (diverg_info_map->find(call_param.node_name) == diverg_info_map->end()) {
1237 return true;
1238 }
1239 if (FullQuantQuantizer::CheckFp32TensorVec(call_param.node_name, after_outputs) != RET_OK) {
1240 return true;
1241 }
1242 int output_i = 0;
1243 // all outputs are same dtype.
1244 for (const auto &tensor : after_outputs) {
1245 const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
1246 MS_ASSERT(tensor_data != nullptr);
1247 size_t elem_count = tensor->ElementsNum();
1248 vector<float> data(tensor_data, tensor_data + elem_count);
1249 auto ret = this->calibrator_->UpdateDataFrequency(data, (*diverg_info_map)[call_param.node_name][output_i++]);
1250 if (ret != RET_OK) {
1251 return false;
1252 }
1253 }
1254 return true;
1255 };
1256 auto status = fp32_session_->RunGraph(before_callback, after_callBack);
1257 if (status != RET_OK) {
1258 MS_LOG(ERROR) << "run model failed!";
1259 return RET_ERROR;
1260 }
1261 }
1262
1263 return RET_OK;
1264 }
1265
ComputeThreshold()1266 STATUS FullQuantQuantizer::ComputeThreshold() { return this->calibrator_->ComputeThreshold(); }
1267
// Full-quantization driver.  Validates the calibration config, builds an fp32
// session, runs the calibration pipeline (min/max -> intervals -> histograms
// -> thresholds), writes quant params into the graph, inserts quant casts,
// and optionally performs bias correction against a freshly built int8
// session.
STATUS FullQuantQuantizer::DoQuantize(FuncGraphPtr func_graph) {
  MS_LOG(INFO) << "start to parse config file";
  if (this->calibrator_ == nullptr) {
    MS_LOG(ERROR) << "calibrator is null!";
    return RET_ERROR;
  }
  calibrator_->full_quant_param_ = flags.fullQuantParam;
  calibrator_->data_pre_process_param_ = flags.dataPreProcessParam;
  // The calibration dataset is mandatory for full quantization.
  if (flags.dataPreProcessParam.calibrate_path.empty()) {
    MS_LOG(ERROR) << "calibrate path must pass. The format is input_name_1:input_1_dir,input_name_2:input_2_dir.";
    return RET_INPUT_PARAM_INVALID;
  }
  if (flags.dataPreProcessParam.calibrate_size <= kMinSize || flags.dataPreProcessParam.calibrate_size > kMaxSize) {
    MS_LOG(ERROR) << "calibrate size must pass and the size should in [1, 65535].";
    return RET_INPUT_PARAM_INVALID;
  }
  if (flags.dataPreProcessParam.input_type == preprocess::INPUT_TYPE_MAX) {
    MS_LOG(ERROR) << "input_type must pass IMAGE | BIN.";
    return RET_INPUT_PARAM_INVALID;
  }
  // Mark quantizable nodes and clear stale quant params.
  STATUS status = PreProcess();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "do pre process failed!";
    return status;
  }

  // anf -- fb
  // Build an un-quantized (fp32) session used for all calibration passes.
  flags.commonQuantParam.quant_type = schema::QuantType_QUANT_NONE;
  MS_LOG(INFO) << "start create session";
  auto sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
  fp32_session_ = sm.session;
  fp32_model_ = sm.model;
  if (fp32_session_ == nullptr || fp32_model_ == nullptr) {
    MS_LOG(ERROR) << "create session failed!";
    return RET_ERROR;
  }
  MS_LOG(INFO) << "start to update divergence's max value";
  status = DoInference();
  if (status != RET_OK) {
    return status;
  }
  MS_LOG(INFO) << "start to update divergence's interval";
  status = UpdateDivergeInterval();
  if (status != RET_OK) {
    return status;
  }
  MS_LOG(INFO) << "start to collect data's distribution";
  status = CollectDataFrequency();
  if (status != RET_OK) {
    return status;
  }
  MS_LOG(INFO) << "compute the best threshold";
  status = ComputeThreshold();
  if (status != RET_OK) {
    return status;
  }
  MS_LOG(INFO) << "start to generate quant param and quantize tensor's data";
  status = QuantNode();
  if (status != RET_OK) {
    return status;
  }

  // add quant_cast
  quant::QuantCast quant_cast;
  status = quant_cast.Run(func_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "add QuantCast error";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
    return RET_ERROR;
  }

  if (calibrator_->GetBiasCorrection()) {
    // init in8 session
    // A second, quantized session is needed to compare int8 vs fp32 outputs.
    MS_LOG(INFO) << "create quant session";
    flags.commonQuantParam.quant_type = schema::QuantType_QUANT_ALL;
    auto int8_sm = CreateSessionByFuncGraph(func_graph, flags, calibrator_->GetThreadNum());
    int8_session_ = int8_sm.session;
    int8_model_ = int8_sm.model;
    if (int8_session_ == nullptr || int8_model_ == nullptr) {
      MS_LOG(ERROR) << "create session failed!";
      return RET_ERROR;
    }
    MS_LOG(INFO) << "do bias correction";
    status = BiasCorrection(func_graph);
    if (status != RET_OK) {
      // Bias correction is best-effort: failure is logged but not fatal.
      MS_LOG(WARNING) << "BiasCorrection failed.";
    }
  }
  return RET_OK;
}
1358
OpInputDataHandle(OperationType type,const string & op_name,std::vector<float> * data)1359 bool FullQuantQuantizer::OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
1360 MS_ASSERT(data != nullptr);
1361 std::lock_guard<std::mutex> lg(mutex_op_input);
1362 if (type == STORE) {
1363 if (fp32_op_input_map.find(op_name) != fp32_op_input_map.end()) {
1364 // the data has not been fetched by int8 model
1365 return false;
1366 }
1367 fp32_op_input_map[op_name] = *data;
1368 return true;
1369 } else if (type == FETCH) {
1370 if (fp32_op_input_map.find(op_name) == fp32_op_input_map.end()) {
1371 // the data not generated by fp32 model yet
1372 return false;
1373 }
1374 *data = fp32_op_input_map[op_name];
1375 fp32_op_input_map.erase(op_name);
1376 return true;
1377 } else {
1378 MS_LOG(ERROR) << "unexpected type: " << type;
1379 }
1380 return false;
1381 }
1382
OpOutputChMeanDataHandle(OperationType type,const string & op_name,std::vector<float> * data)1383 bool FullQuantQuantizer::OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data) {
1384 MS_ASSERT(data != nullptr);
1385 std::lock_guard<std::mutex> lg(mutex_op_output);
1386 if (type == STORE) {
1387 if (fp32_op_output_ch_mean_map.find(op_name) != fp32_op_output_ch_mean_map.end()) {
1388 // the data has not been fetched by int8 model
1389 return false;
1390 }
1391 fp32_op_output_ch_mean_map[op_name] = *data;
1392 return true;
1393 } else if (type == FETCH) {
1394 if (fp32_op_output_ch_mean_map.find(op_name) == fp32_op_output_ch_mean_map.end()) {
1395 // the data not generated by fp32 model yet
1396 return false;
1397 }
1398 *data = fp32_op_output_ch_mean_map[op_name];
1399 fp32_op_output_ch_mean_map.erase(op_name);
1400 return true;
1401 } else {
1402 MS_LOG(ERROR) << "unexpected type: " << type;
1403 }
1404 return false;
1405 }
1406
// Build the before-kernel callback for bias correction.
// fp32 variant (int8_op == false): copies each conv's float input activation
// into the shared buffer for the int8 pass.
// int8 variant (int8_op == true): fetches that fp32 activation, quantizes it
// with the tensor's own (scale, zp) and overwrites the int8 conv input, so
// both models see the same input and output differences come only from
// quantization of weights/accumulation.
KernelCallBack FullQuantQuantizer::GetBeforeCallBack(bool int8_op) {
  KernelCallBack before_call_back;
  if (!int8_op) {
    before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                              const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                              const CallBackParam &callParam) -> bool {
      if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
        if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, before_inputs) != RET_OK) {
          return true;
        }
        auto tensor = before_inputs[0];
        MS_ASSERT(tensor != nullptr);
        size_t elem_count = tensor->ElementsNum();
        std::vector<float> fp32_op_input(elem_count);
        auto ret =
          memcpy_s(fp32_op_input.data(), fp32_op_input.size() * sizeof(float), tensor->MutableData(), tensor->Size());
        if (ret != EOK) {
          MS_LOG(ERROR) << "memcpy error: " << ret;
          return false;
        }
        // Busy-wait until the int8 thread has consumed the previous batch's
        // data for this op (single-slot handoff).
        while (!OpInputDataHandle(STORE, callParam.node_name, &fp32_op_input)) {
          std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));
        }
      }
      return true;
    };
  } else {
    before_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &before_inputs,
                              const std::vector<mindspore::tensor::MSTensor *> &before_outputs,
                              const CallBackParam &callParam) -> bool {
      if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
        vector<float> fp32_op_input;
        // Busy-wait until the fp32 thread has stored this op's activation.
        while (!OpInputDataHandle(FETCH, callParam.node_name, &fp32_op_input)) {
          std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));
        }
        auto tensor = before_inputs[0];
        MS_ASSERT(tensor != nullptr);
        if (tensor->data_type() != kNumberTypeInt8) {
          MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
          return false;
        }
        // do quantization: activation is always per layer quantized
        std::vector<int8_t> quant_datas;
        auto quant_params = tensor->quant_params();
        if (quant_params.size() != 1) {
          MS_LOG(ERROR) << "unexpected quant_params size: " << quant_params.size();
          return false;
        }
        schema::QuantParamT quant_param_t;
        quant_param_t.scale = quant_params[0].scale;
        quant_param_t.zeroPoint = quant_params[0].zeroPoint;
        for (auto float_data : fp32_op_input) {
          auto quant_data = QuantizeData<int8_t>(float_data, &quant_param_t, quant_max, quant_min);
          quant_datas.push_back(quant_data);
        }

        // Guard against a size mismatch before overwriting the tensor.
        if (tensor->Size() != quant_datas.size() * sizeof(int8_t)) {
          MS_LOG(ERROR) << "unexpected tensor size: " << quant_datas.size()
                        << " not the same with: " << quant_datas.size() * sizeof(int8_t);
          return false;
        }

        auto ret =
          memcpy_s(tensor->MutableData(), tensor->Size(), quant_datas.data(), quant_datas.size() * sizeof(int8_t));
        if (ret != EOK) {
          MS_LOG(ERROR) << "memcpy error: " << ret;
          return false;
        }
      }
      return true;
    };
  }
  return before_call_back;
}
1481
GetAfterCallBack(bool int8_op)1482 KernelCallBack FullQuantQuantizer::GetAfterCallBack(bool int8_op) {
1483 KernelCallBack after_call_back;
1484 if (!int8_op) {
1485 return GetFloatAfterCallBack();
1486 }
1487 return GetInt8AfterCallBack();
1488 }
1489
GetInt8AfterCallBack()1490 KernelCallBack FullQuantQuantizer::GetInt8AfterCallBack() {
1491 KernelCallBack after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
1492 const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
1493 const CallBackParam &callParam) -> bool {
1494 if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
1495 vector<float> fp32_op_output_ch_mean;
1496 while (!OpOutputChMeanDataHandle(FETCH, callParam.node_name, &fp32_op_output_ch_mean)) {
1497 std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));
1498 }
1499 auto tensor = afterOutputs[0];
1500 MS_ASSERT(tensor != nullptr);
1501 if (tensor->data_type() != kNumberTypeInt8) {
1502 MS_LOG(ERROR) << "unexpected tensor type: " << tensor->data_type();
1503 return false;
1504 }
1505 const int8_t *tensor_data = static_cast<int8_t *>(tensor->MutableData());
1506 size_t elem_count = tensor->ElementsNum();
1507 auto shapes = tensor->shape();
1508 if (shapes.size() != 4) {
1509 MS_LOG(ERROR) << "unexpected shape size: " << shapes.size();
1510 return false;
1511 }
1512 // suppose the the format is NHWC
1513 auto channels = shapes[3];
1514 if (channels == 0) {
1515 MS_LOG(ERROR) << "unexpected channels: 0";
1516 return false;
1517 }
1518 auto quant_params = tensor->quant_params();
1519 if (quant_params.size() != 1) {
1520 MS_LOG(ERROR) << "unexpected activatation quant_params size: " << quant_params.size();
1521 return false;
1522 }
1523 auto scale = quant_params[0].scale;
1524 auto zp = quant_params[0].zeroPoint;
1525 std::vector<float> dequant_op_output_ch_mean(channels);
1526 auto one_filter_size = elem_count / channels;
1527 for (int i = 0; i < channels; i++) {
1528 float sum = 0;
1529 for (size_t j = 0; j < one_filter_size; j++) {
1530 auto index = j * channels + i;
1531 if (index >= elem_count) {
1532 MS_LOG(ERROR) << "over flow!";
1533 return false;
1534 }
1535 // deuqant activation
1536 auto float_data = scale * (tensor_data[index] - zp);
1537 sum += float_data;
1538 }
1539 if (one_filter_size == 0) {
1540 MS_LOG(ERROR) << "divisor 'one_filter_size' cannot be 0.";
1541 return false;
1542 }
1543 sum = sum / one_filter_size;
1544 dequant_op_output_ch_mean[i] = sum;
1545 }
1546 std::transform(fp32_op_output_ch_mean.begin(), fp32_op_output_ch_mean.end(), dequant_op_output_ch_mean.begin(),
1547 dequant_op_output_ch_mean.begin(), std::minus<>());
1548
1549 if (op_bias_diff_map.find(callParam.node_name) != op_bias_diff_map.end()) {
1550 auto &bias_diff = op_bias_diff_map[callParam.node_name];
1551 std::transform(bias_diff.begin(), bias_diff.end(), dequant_op_output_ch_mean.begin(), bias_diff.begin(),
1552 std::plus<>());
1553 } else {
1554 op_bias_diff_map[callParam.node_name] = dequant_op_output_ch_mean;
1555 }
1556 }
1557 return true;
1558 };
1559 return after_call_back;
1560 }
1561
GetFloatAfterCallBack()1562 KernelCallBack FullQuantQuantizer::GetFloatAfterCallBack() {
1563 KernelCallBack after_call_back = [this](const std::vector<mindspore::tensor::MSTensor *> &afterInputs,
1564 const std::vector<mindspore::tensor::MSTensor *> &afterOutputs,
1565 const CallBackParam &callParam) -> bool {
1566 if (callParam.node_type == kTypeConv2D || callParam.node_type == kTypeDepthwiseConv2D) {
1567 if (FullQuantQuantizer::CheckFp32TensorVec(callParam.node_name, afterOutputs) != RET_OK) {
1568 return true;
1569 }
1570 auto tensor = afterOutputs[0];
1571 MS_ASSERT(tensor != nullptr);
1572 const auto *tensor_data = static_cast<const float *>(tensor->MutableData());
1573 size_t elem_count = tensor->ElementsNum();
1574 auto shapes = tensor->shape();
1575 if (shapes.size() != 4) {
1576 MS_LOG(ERROR) << "unexpected shape size: " << shapes.size();
1577 return false;
1578 }
1579 // suppose the activation format: NHWC
1580 auto channels = shapes[3];
1581 if (channels == 0) {
1582 MS_LOG(ERROR) << "unexpected channels: 0";
1583 return false;
1584 }
1585 std::vector<float> fp32_op_output_ch_mean(channels);
1586 auto one_filter_size = elem_count / channels;
1587 for (int i = 0; i < channels; i++) {
1588 float sum = 0;
1589 for (size_t j = 0; j < one_filter_size; j++) {
1590 auto index = j * channels + i;
1591 if (index >= elem_count) {
1592 MS_LOG(ERROR) << "over flow!";
1593 return false;
1594 }
1595 sum += tensor_data[index];
1596 }
1597 if (one_filter_size == 0) {
1598 MS_LOG(ERROR) << "divisor 'one_filter_size' cannot be 0.";
1599 return false;
1600 }
1601 sum = sum / one_filter_size;
1602 fp32_op_output_ch_mean[i] = sum;
1603 }
1604 while (!OpOutputChMeanDataHandle(STORE, callParam.node_name, &fp32_op_output_ch_mean)) {
1605 std::this_thread::sleep_for(std::chrono::milliseconds(kMillisecondsBase));
1606 }
1607 }
1608 return true;
1609 };
1610 return after_call_back;
1611 }
1612 } // namespace mindspore::lite::quant
1613