1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18
19 #include "tools/converter/quantizer/quantization_optimizer.h"
20 #include <memory>
21 #include <string>
22 #include <deque>
23 #include <map>
24 #include <set>
25 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
26 #include "tools/lite_exporter/fetch_content.h"
27 #include "base/base.h"
28 #include "tools/converter/quantizer/quantize_util.h"
29 #include "tools/converter/quantizer/weight_quantizer.h"
30 #include "tools/converter/quantizer/full_quant_quantizer.h"
31 #include "tools/converter/quantizer/debug_info_manager.h"
32 #include "tools/converter/quantizer/parameter_tunner.h"
33 #include "tools/converter/quantizer/dynamic_quantizer.h"
34 #include "tools/lite_exporter/anf_exporter.h"
35 #include "tools/converter/quantizer/cle_strategy.h"
36 #include "tools/optimizer/common/pass_manager_extends.h"
37 #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
38 #include "include/backend/optimizer/graph_optimizer.h"
39 #include "tools/optimizer/graph/infershape_pass.h"
40 #include "tools/converter/quantizer/split_shared_bias.h"
41
42 namespace mindspore::lite::quant {
DoFullQuant(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)43 int QuantizationOptimizer::DoFullQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
44 auto quantizer = std::make_unique<FullQuantQuantizer>(param);
45 if (quantizer == nullptr) {
46 MS_LOG(ERROR) << "New FullQuantQuantizer failed";
47 return RET_ERROR;
48 }
49 auto status = quantizer->DoQuantize(old_graph);
50 if (status != RET_OK) {
51 MS_LOG(ERROR) << "DoQuantization failed " << status;
52 return RET_ERROR;
53 }
54 return RET_OK;
55 }
56
DoWeightQuant(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)57 int QuantizationOptimizer::DoWeightQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
58 double init_scale = param->mixedBitWeightQuantParam.init_scale;
59 if (param->commonQuantParam.bit_num == 0 && param->mixedBitWeightQuantParam.auto_tune) {
60 ParameterOptimizer optimizer;
61 auto status = optimizer.GridSearchForScale(old_graph, param, &init_scale);
62 if (status != RET_OK) {
63 MS_LOG(ERROR) << "Grid search with scale failed.";
64 return status;
65 }
66 auto quantizer = std::make_unique<WeightQuantizer>(param, init_scale);
67 if (quantizer == nullptr) {
68 MS_LOG(ERROR) << "New WeightQuantizer failed";
69 return RET_ERROR;
70 }
71 status = static_cast<WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph);
72 if (status != RET_OK) {
73 MS_LOG(ERROR) << "DoQuantization failed " << status;
74 return RET_ERROR;
75 }
76 } else {
77 auto quantizer = std::make_unique<WeightQuantizer>(param);
78 if (quantizer == nullptr) {
79 MS_LOG(ERROR) << "New WeightQuantizer failed";
80 return RET_ERROR;
81 }
82 auto status = quantizer->DoQuantize(old_graph);
83 if (status != RET_OK) {
84 MS_LOG(ERROR) << "DoQuantization failed " << status;
85 return RET_ERROR;
86 }
87 }
88 return RET_OK;
89 }
90
DoDynamicQuant(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)91 int DoDynamicQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
92 auto quantizer = std::make_unique<DynamicQuantizer>(param);
93 if (quantizer == nullptr) {
94 MS_LOG(ERROR) << "New DynamicQuantizer failed";
95 return RET_ERROR;
96 }
97 auto status = quantizer->DoQuantize(old_graph);
98 if (status != RET_OK) {
99 MS_LOG(ERROR) << "DoQuantization failed " << status;
100 return RET_ERROR;
101 }
102 return RET_OK;
103 }
104
ParseLiteModel(const FuncGraphPtr & func_graph,const std::shared_ptr<ConverterPara> & param)105 std::shared_ptr<lite::Model> ParseLiteModel(const FuncGraphPtr &func_graph,
106 const std::shared_ptr<ConverterPara> ¶m) {
107 FuncGraphPtr func_graph_clone;
108 if (CloneFuncGraph(func_graph, param, &func_graph_clone) != RET_OK) {
109 MS_LOG(ERROR) << "Clone func_graph failed";
110 return nullptr;
111 }
112 auto meta_graph = Export(func_graph_clone, true, true);
113 if (meta_graph == nullptr) {
114 MS_LOG(ERROR) << "Export to meta_graph failed";
115 return nullptr;
116 }
117
118 // transform
119 GraphDefTransform fb_transform;
120 fb_transform.SetGraphDef(meta_graph);
121 auto status = fb_transform.Transform(param);
122 if (status != RET_OK) {
123 MS_LOG(ERROR) << "FBTransform model failed";
124 delete meta_graph;
125 return nullptr;
126 }
127 meta_graph->version = Version();
128
129 flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
130 auto offset = schema::MetaGraph::Pack(builder, meta_graph);
131 builder.Finish(offset);
132 schema::FinishMetaGraphBuffer(builder, offset);
133 size_t size = builder.GetSize();
134 auto content = reinterpret_cast<const char *>(builder.GetBufferPointer());
135 if (content == nullptr) {
136 MS_LOG(ERROR) << "GetBufferPointer nullptr";
137 delete meta_graph;
138 return nullptr;
139 }
140 delete meta_graph;
141 return std::shared_ptr<lite::Model>(LiteModel::Import(content, size));
142 }
143
// Builds the quantized model from |old_graph| and compares it against the
// pre-quantization baseline (|origin_model| / |origin_lite_model|) through
// DebugInfoManager, dumping per-op statistics.
// Returns RET_OK on success, an error status otherwise.
int DoQuantDebug(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param,
                 const std::shared_ptr<mindspore::Model> &origin_model,
                 const std::shared_ptr<lite::Model> &origin_lite_model) {
  auto quant_model = std::make_shared<mindspore::Model>();
  CHECK_NULL_RETURN(quant_model);
  size_t size = 0;
  auto status = BuildModelByFuncGraph(quant_model, old_graph, param, &size);
  if (status != kSuccess) {
    MS_LOG(ERROR) << "Build model failed";
    return RET_ERROR;
  }
  std::map<std::string, OpParameter *> op_parameters;
  auto ret = FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Fetch op parameter from funcgraph failed";
    return ret;
  }
  // Defined immediately after the fetch so every exit path below releases the
  // malloc'ed OpParameter objects (the previous code leaked them when
  // ParseLiteModel failed or origin_lite_model was null). Iterate by
  // reference: a by-value loop would null the copy's pointer, not the map's.
  auto free_buffer = [&] {
    for (auto &parameter : op_parameters) {
      if (parameter.second != nullptr) {
        free(parameter.second);
        parameter.second = nullptr;
      }
    }
    op_parameters.clear();
  };
  DebugInfoManager manager;

  auto quant_lite_model = ParseLiteModel(old_graph, param);
  if (quant_lite_model == nullptr) {
    MS_LOG(ERROR) << "Parse quant lite model failed";
    free_buffer();
    return RET_ERROR;
  }
  if (origin_lite_model == nullptr) {
    MS_LOG(ERROR) << "Origin lite model nullptr.";
    free_buffer();
    return RET_ERROR;
  }

  ret = manager.CompareOriginWithQuant(origin_model, quant_model, op_parameters, param, origin_lite_model,
                                       quant_lite_model);
  free_buffer();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Compare origin with quant failed.";
    return ret;
  }
  return RET_OK;
}
192
ConvertValueNodeToParameter(const FuncGraphPtr & func_graph)193 int ConvertValueNodeToParameter(const FuncGraphPtr &func_graph) {
194 auto cnodes = func_graph->GetOrderedCnodes();
195 for (auto &cnode : cnodes) {
196 for (size_t i = kPrimOffset; i < cnode->size(); ++i) {
197 auto input = cnode->input(i);
198 if (!input->isa<ValueNode>()) {
199 continue;
200 }
201 auto tensor_info = input->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
202 if (tensor_info == nullptr) {
203 MS_LOG(INFO) << cnode->fullname_with_scope() << " input index: " << i << " cast tensor nullptr.";
204 continue;
205 }
206 auto parameter = func_graph->add_parameter();
207 auto status = InitParameterFromTensorInfo(parameter, tensor_info);
208 if (status != RET_OK) {
209 MS_LOG(ERROR) << "Init parameter From tensor failed, tenor: " << tensor_info->name();
210 return status;
211 }
212 parameter->set_name(input->fullname_with_scope());
213 auto manage = Manage(func_graph);
214 manage->Replace(input, parameter);
215 }
216 }
217 return RET_OK;
218 }
219
PrepareQuantize(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)220 int QuantizationOptimizer::PrepareQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
221 if (!param->train_model && param->save_type == kMindIR) {
222 auto status = ConvertValueNodeToParameter(old_graph);
223 if (status != RET_OK) {
224 MS_LOG(ERROR) << "Convert value node To parameter failed.";
225 return status;
226 }
227 }
228
229 auto convert_pm = std::make_shared<opt::LitePassManager>("anf graph convert pass manager", true);
230 convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
231 auto optimizer = std::make_shared<opt::GraphOptimizer>();
232 optimizer->AddPassManager(convert_pm);
233 if (optimizer->Optimize(old_graph) == nullptr) {
234 MS_LOG(ERROR) << "run graph pass failed";
235 return RET_ERROR;
236 }
237
238 bool per_layer = param->commonQuantParam.quant_type == quant::QUANT_ALL && !param->fullQuantParam.per_channel &&
239 param->fullQuantParam.target_device != DSP;
240 if (per_layer) {
241 CLEStrategy cle_strategy(old_graph);
242 auto status = cle_strategy.Run();
243 if (status != RET_OK) {
244 MS_LOG(ERROR) << "do cle_strategy failed!";
245 return status;
246 }
247 }
248
249 if (param->commonQuantParam.quant_type == quant::QUANT_ALL && param->fullQuantParam.bias_correction) {
250 SplitSharedBias split_shared_bias(old_graph, param);
251 if (split_shared_bias.Run() != RET_OK) {
252 MS_LOG(ERROR) << "split shared bias node failed!";
253 return RET_ERROR;
254 }
255 }
256 return RET_OK;
257 }
258
// Quantizes one (sub)graph end to end: preparation, optional fp32 baseline
// capture for debug, the quant pass selected by quant_type, a cast-fusion
// cleanup, and the optional origin-vs-quant debug comparison.
// Returns RET_OK on success, an error status otherwise.
int QuantizationOptimizer::DoSingleGraphQuantize(const FuncGraphPtr &func_graph,
                                                 const std::shared_ptr<ConverterPara> &param) {
  CHECK_NULL_RETURN(param);
  int status = PrepareQuantize(func_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "PrepareQuantize failed.";
    return status;
  }

  std::shared_ptr<mindspore::Model> origin;
  std::shared_ptr<lite::Model> origin_lite_model;
  if (param->commonQuantParam.is_debug) {  // Bak fp32 model for debug
    // Temporarily force QUANT_NONE so the baseline build is unquantized;
    // the original quant_type is restored right after the build.
    auto quant_type = param->commonQuantParam.quant_type;
    param->commonQuantParam.quant_type = quant::QUANT_NONE;
    origin = std::make_shared<mindspore::Model>();
    CHECK_NULL_RETURN(origin);
    size_t size = 0;
    auto ret = BuildModelByFuncGraph(origin, func_graph, param, &size);
    param->commonQuantParam.quant_type = quant_type;
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Build model failed";
      return RET_ERROR;
    }
    origin_lite_model = ParseLiteModel(func_graph, param);
    if (origin_lite_model == nullptr) {
      MS_LOG(ERROR) << "Parse lite model failed.";
      return RET_ERROR;
    }
  }
  // Dispatch on the configured quantization mode; QUANT_NONE falls through
  // with no quantization applied.
  if (param->commonQuantParam.quant_type == quant::QUANT_ALL) {  // Full Quantization
    // Full quant calibration requires fp32 weights/activations.
    status = ConvertFp16ToFp32(func_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Converter fp16 to fp32 failed.";
      return status;
    }
    status = DoFullQuant(func_graph, param);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do full quant failed.";
      return status;
    }
  } else if (param->commonQuantParam.quant_type == quant::QUANT_WEIGHT) {  // Weight Quantization
    status = DoWeightQuant(func_graph, param);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do weight quant failed.";
      return status;
    }
  } else if (param->commonQuantParam.quant_type == quant::QUANT_DYNAMIC) {  // Dynamic Quantization
    status = DoDynamicQuant(func_graph, param);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do dynamic quant failed.";
      return status;
    }
  }

  // Fuse redundant QuantDTypeCast nodes introduced by quantization
  // (skipped on ASCEND targets).
  if (param->fullQuantParam.target_device != ASCEND) {
    auto optimizer = std::make_shared<opt::GraphOptimizer>();
    CHECK_NULL_RETURN(optimizer);
    auto fusion_pm = std::make_shared<opt::LitePassManager>("fusion pass manager after quant", false);
    CHECK_NULL_RETURN(fusion_pm);
    fusion_pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
    optimizer->AddPassManager(fusion_pm);
    if (optimizer->Optimize(func_graph) == nullptr) {
      MS_LOG(ERROR) << "run cast node fusion failed.";
      return RET_ERROR;
    }
  }

  if (param->commonQuantParam.is_debug) {
    // Compare the quantized graph against the fp32 baseline captured above.
    status = DoQuantDebug(func_graph, param, origin, origin_lite_model);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do quant debug failed.";
      return status;
    }
  }
  return RET_OK;
}
335
Run(const mindspore::FuncGraphPtr & func_graph)336 int QuantizationOptimizer::Run(const mindspore::FuncGraphPtr &func_graph) {
337 if (param_->commonQuantParam.quant_type == quant::QUANT_NONE || param_->fullQuantParam.target_device == ASCEND) {
338 return RET_OK;
339 }
340 // set manager
341 if (func_graph->manager() == nullptr) {
342 auto root_func_manager = Manage(func_graph);
343 std::set<FuncGraphPtr> all_func_graphs = {};
344 lite::GetAllFuncGraph(func_graph, &all_func_graphs);
345 for (auto graph : all_func_graphs) {
346 graph->set_manager(root_func_manager);
347 }
348 }
349
350 std::set<FuncGraphPtr> all_func_graphs{};
351 quant::GetFuncGraphs(func_graph, &all_func_graphs);
352 // Support for multi-subgraph models
353 for (auto &item : all_func_graphs) {
354 auto status = DoSingleGraphQuantize(item, param_);
355 if (status != RET_OK) {
356 MS_LOG(ERROR) << "Do Quantize failed.";
357 return status;
358 }
359 }
360 if (param_->fullQuantParam.target_device != ASCEND) {
361 auto optimizer = std::make_shared<opt::GraphOptimizer>();
362 CHECK_NULL_RETURN(optimizer);
363 auto fusion_pm = std::make_shared<opt::LitePassManager>("fusion pass manager after quant", false);
364 CHECK_NULL_RETURN(fusion_pm);
365 fusion_pm->AddPass(std::make_shared<opt::InferShapePass>(param_->fmk_type, param_->train_model));
366 optimizer->AddPassManager(fusion_pm);
367 if (optimizer->Optimize(func_graph) == nullptr) {
368 MS_LOG(ERROR) << "run infershape failed.";
369 return RET_ERROR;
370 }
371 }
372 return RET_OK;
373 }
374 } // namespace mindspore::lite::quant
375