• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
18 #define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
19 
20 #include <string>
21 #include <memory>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 #include <cfloat>
26 #include <map>
27 #include "ops/primitive_c.h"
28 #include "schema/inner/model_generated.h"
29 #include "src/lite_session.h"
30 #include "tools/converter/quantizer/quantizer.h"
31 #include "tools/converter/converter.h"
32 #include "include/ms_tensor.h"
33 #include "tools/converter/quantizer/quantize_util.h"
34 #include "tools/converter/quantizer/quant_params.h"
35 #include "tools/converter/preprocess/preprocess_param.h"
36 
37 namespace mindspore::lite::quant {
38 class Calibrator;
39 constexpr int kDefaultBinNumber = 2048;
40 struct DivergInfo {
41   std::vector<float> histogram;
42   CNodePtr cnode;
43   int bin_num = 0;
44   float interval = 0;
45   float max = 0.0f;
46   float min = 0.0f;
47   float best_T = 0.0f;
48   size_t bit_num = 0;
49   int quant_max = 255;
50   int quant_min = 0;
51   ActivationQuantizedMethod activation_quant_method = MAX_MIN;
52   std::vector<float> min_datas;
53   std::vector<float> max_datas;
54   std::pair<float, float> percent_result{0.0, 0.0};
55   float scale_tmp = 0;
56   DivergInfo() = default;
DivergInfoDivergInfo57   DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min,
58              ActivationQuantizedMethod activation_quant_method) {
59     this->activation_quant_method = activation_quant_method;
60     this->cnode = std::move(cnode);
61     this->bin_num = bins;
62     this->bit_num = bits;
63     histogram.resize(bin_num);
64     max = -FLT_MAX;
65     min = FLT_MAX;
66     this->quant_max = quant_max;
67     this->quant_min = quant_min;
68     std::fill(histogram.begin(), histogram.end(), 1.0e-7);
69   }
70 
71   STATUS RecordMaxMinValue(const std::vector<float> &data);
72 
73   STATUS RecordMaxMinValueArray(const std::vector<float> &data);
74 
75   void UpdateInterval();
76 
77   STATUS UpdateHistogram(const std::vector<float> &data);
78 
79   void DumpHistogram();
80 
81   void HandleBinForKL(int quant_bint_nums, int bin_index, std::vector<float> *quantized_histogram,
82                       std::vector<float> *expanded_histogram);
83 
84   STATUS ComputeThreshold();
85 
86   std::pair<CNodePtr, float> GetScale();
87 
88   std::pair<CNodePtr, int32_t> GetZeropoint();
89 };
90 
/// Post-training full (activation + weight) quantizer: runs float inference
/// over a calibration dataset, collects per-tensor statistics via Calibrator,
/// then rewrites the graph with int8 quant parameters and optionally applies
/// bias correction using a parallel int8 inference pass.
class FullQuantQuantizer : public Quantizer {
 public:
  FullQuantQuantizer(FuncGraphPtr graph, int bit_num, TypeId target_type = kNumberTypeInt8, bool per_channel = true);
  ~FullQuantQuantizer() override;

  // Entry point: calibrate, quantize nodes, and (if enabled) correct biases.
  STATUS DoQuantize(FuncGraphPtr func_graph) override;

  size_t bit_num;
  int quant_max{INT8_MAX};
  int quant_min{INT8_MIN};

 private:
  bool per_channel_{true};
  TypeId target_type_{kNumberTypeInt8};
  std::unique_ptr<Calibrator> calibrator_{nullptr};

  // Float and int8 sessions/models used side by side for bias correction.
  session::LiteSession *fp32_session_{nullptr};
  Model *fp32_model_{nullptr};
  session::LiteSession *int8_session_{nullptr};
  Model *int8_model_{nullptr};

  std::map<std::string, std::vector<float>> fp32_op_input_map;           // concurrency
  std::map<std::string, std::vector<float>> fp32_op_output_ch_mean_map;  // concurrency
  std::map<std::string, std::vector<float>> op_bias_diff_map;            // only use by int8 model
  // Guard the two maps above against concurrent callback access.
  std::mutex mutex_op_input;
  std::mutex mutex_op_output;

  // Whether a data handle call stores into or fetches from the shared maps.
  enum OperationType {
    STORE,
    FETCH,
  };

  bool OpInputDataHandle(OperationType type, const string &op_name, std::vector<float> *data);
  bool OpOutputChMeanDataHandle(OperationType type, const string &op_name, std::vector<float> *data);

  // NOTE(review): depthwise conv intentionally maps to Conv2DFusion as well —
  // presumably depthwise is folded into that schema type; verify against schema.
  const std::string kTypeConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2DFusion);
  const std::string kTypeDepthwiseConv2D = schema::EnumNamePrimitiveType(schema::PrimitiveType_Conv2DFusion);
  const std::string kTypeConcat = schema::EnumNamePrimitiveType(schema::PrimitiveType_Concat);

  // Builds the calibrator and float session before statistics collection.
  STATUS PreProcess();

  static STATUS CheckFp32TensorVec(const std::string &node_name,
                                   const std::vector<mindspore::tensor::MSTensor *> &tensor_vec);

  // Runs float inference over the calibration set to record min/max.
  STATUS DoInference();

  STATUS UpdateDivergeInterval();

  // Second calibration pass: fills the histograms for KL thresholding.
  STATUS CollectDataFrequency();

  STATUS ComputeThreshold();

  STATUS QuantNodeSimpleOp(const CNodePtr &cnode);

  // Walks the graph and writes quant params into every supported node.
  STATUS QuantNode();

  STATUS SetInOutQuantParam(const AnfNodePtr &input_node, const std::unique_ptr<DivergInfo> &info,
                            const PrimitivePtr &primitive, bool is_input, size_t index) const;

  STATUS DoWeightQuant(const std::string &op_name, const AnfNodePtr &weight, const PrimitivePtr &primitive,
                       bool per_channel, int input_index) const;

  STATUS DoParameterNodeQuant(const CNodePtr &cnode, const AnfNodePtr &input_node, size_t input_index);

  static STATUS DoBiasQuant(const AnfNodePtr &bias, const PrimitivePtr &primitive);
  // Runs the quantized model to measure per-channel output drift.
  STATUS Int8Inference();
  STATUS BiasCorrection(const FuncGraphPtr &func_graph);
  STATUS BiasCorrection(const FuncGraphPtr &func_graph, const CNodePtr &cnode);
  // Session callbacks that capture tensors around each kernel invocation.
  KernelCallBack GetBeforeCallBack(bool int8_op);
  KernelCallBack GetAfterCallBack(bool int8_op);
  KernelCallBack GetInt8AfterCallBack();
  KernelCallBack GetFloatAfterCallBack();
};
164 
/// Owns the calibration configuration and the per-tensor DivergInfo tables
/// (inputs and outputs, keyed by node name) that FullQuantQuantizer fills
/// during float inference and later reads to compute quant parameters.
class Calibrator {
 public:
  explicit Calibrator(size_t bit_num, int quant_max, int quant_min)
      : bit_num_(bit_num), quant_max_(quant_max), quant_min_(quant_min) {}

  ~Calibrator() = default;

  // Loads the image_index-th calibration sample into the given input tensor.
  STATUS GenerateInputData(const std::string &input_name, size_t image_index,
                           mindspore::tensor::MSTensor *tensor) const;

  size_t GetBatchNum() const { return data_pre_process_param_.calibrate_size; }

  uint32_t GetThreadNum() const { return full_quant_param_.thread_num; }

  bool GetBiasCorrection() const { return full_quant_param_.bias_correction; }

  size_t GetInputNum() const { return data_pre_process_param_.calibrate_path_vector.size(); }

  // Registers a node so its input/output tensors get DivergInfo entries.
  STATUS AddQuantizedOp(const CNodePtr &cnode);

  STATUS RecordMaxMinValue(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);

  STATUS UpdateDivergInterval(std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *diverg_info);

  STATUS UpdateDataFrequency(const std::vector<float> &data, const std::unique_ptr<DivergInfo> &diverg_info);

  // Computes the clipping threshold for every tracked tensor.
  STATUS ComputeThreshold();

  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetInputDivergInfo();

  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> *GetOutputDivergInfo();

  FullQuantParam full_quant_param_;

  preprocess::DataPreProcessParam data_pre_process_param_;

 private:
  // node name -> one DivergInfo per input tensor of that node
  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> inputs_diverg_info_;

  // node name -> one DivergInfo per output tensor of that node
  std::unordered_map<std::string, std::vector<std::unique_ptr<DivergInfo>>> outputs_diverg_info_;

  size_t bit_num_;
  int quant_max_;
  int quant_min_;
};
210 }  // namespace mindspore::lite::quant
211 #endif  // MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER_FULL_QUANT_QUANTIZER_H
212