• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 
19 #include "tools/converter/quantizer/quantization_optimizer.h"
20 #include <memory>
21 #include <string>
22 #include <deque>
23 #include <map>
24 #include <set>
25 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
26 #include "tools/lite_exporter/fetch_content.h"
27 #include "base/base.h"
28 #include "tools/converter/quantizer/quantize_util.h"
29 #include "tools/converter/quantizer/weight_quantizer.h"
30 #include "tools/converter/quantizer/full_quant_quantizer.h"
31 #include "tools/converter/quantizer/debug_info_manager.h"
32 #include "tools/converter/quantizer/parameter_tunner.h"
33 #include "tools/converter/quantizer/dynamic_quantizer.h"
34 #include "tools/lite_exporter/anf_exporter.h"
35 #include "tools/converter/quantizer/cle_strategy.h"
36 #include "tools/optimizer/common/pass_manager_extends.h"
37 #include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
38 #include "include/backend/optimizer/graph_optimizer.h"
39 #include "tools/optimizer/graph/infershape_pass.h"
40 #include "tools/converter/quantizer/split_shared_bias.h"
41 
42 namespace mindspore::lite::quant {
DoFullQuant(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)43 int QuantizationOptimizer::DoFullQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
44   auto quantizer = std::make_unique<FullQuantQuantizer>(param);
45   if (quantizer == nullptr) {
46     MS_LOG(ERROR) << "New FullQuantQuantizer failed";
47     return RET_ERROR;
48   }
49   auto status = quantizer->DoQuantize(old_graph);
50   if (status != RET_OK) {
51     MS_LOG(ERROR) << "DoQuantization failed " << status;
52     return RET_ERROR;
53   }
54   return RET_OK;
55 }
56 
DoWeightQuant(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)57 int QuantizationOptimizer::DoWeightQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
58   double init_scale = param->mixedBitWeightQuantParam.init_scale;
59   if (param->commonQuantParam.bit_num == 0 && param->mixedBitWeightQuantParam.auto_tune) {
60     ParameterOptimizer optimizer;
61     auto status = optimizer.GridSearchForScale(old_graph, param, &init_scale);
62     if (status != RET_OK) {
63       MS_LOG(ERROR) << "Grid search with scale failed.";
64       return status;
65     }
66     auto quantizer = std::make_unique<WeightQuantizer>(param, init_scale);
67     if (quantizer == nullptr) {
68       MS_LOG(ERROR) << "New WeightQuantizer failed";
69       return RET_ERROR;
70     }
71     status = static_cast<WeightQuantizer *>(quantizer.get())->DoQuantize(old_graph);
72     if (status != RET_OK) {
73       MS_LOG(ERROR) << "DoQuantization failed " << status;
74       return RET_ERROR;
75     }
76   } else {
77     auto quantizer = std::make_unique<WeightQuantizer>(param);
78     if (quantizer == nullptr) {
79       MS_LOG(ERROR) << "New WeightQuantizer failed";
80       return RET_ERROR;
81     }
82     auto status = quantizer->DoQuantize(old_graph);
83     if (status != RET_OK) {
84       MS_LOG(ERROR) << "DoQuantization failed " << status;
85       return RET_ERROR;
86     }
87   }
88   return RET_OK;
89 }
90 
DoDynamicQuant(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)91 int DoDynamicQuant(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
92   auto quantizer = std::make_unique<DynamicQuantizer>(param);
93   if (quantizer == nullptr) {
94     MS_LOG(ERROR) << "New DynamicQuantizer failed";
95     return RET_ERROR;
96   }
97   auto status = quantizer->DoQuantize(old_graph);
98   if (status != RET_OK) {
99     MS_LOG(ERROR) << "DoQuantization failed " << status;
100     return RET_ERROR;
101   }
102   return RET_OK;
103 }
104 
ParseLiteModel(const FuncGraphPtr & func_graph,const std::shared_ptr<ConverterPara> & param)105 std::shared_ptr<lite::Model> ParseLiteModel(const FuncGraphPtr &func_graph,
106                                             const std::shared_ptr<ConverterPara> &param) {
107   FuncGraphPtr func_graph_clone;
108   if (CloneFuncGraph(func_graph, param, &func_graph_clone) != RET_OK) {
109     MS_LOG(ERROR) << "Clone func_graph failed";
110     return nullptr;
111   }
112   auto meta_graph = Export(func_graph_clone, true, true);
113   if (meta_graph == nullptr) {
114     MS_LOG(ERROR) << "Export to meta_graph failed";
115     return nullptr;
116   }
117 
118   // transform
119   GraphDefTransform fb_transform;
120   fb_transform.SetGraphDef(meta_graph);
121   auto status = fb_transform.Transform(param);
122   if (status != RET_OK) {
123     MS_LOG(ERROR) << "FBTransform model failed";
124     delete meta_graph;
125     return nullptr;
126   }
127   meta_graph->version = Version();
128 
129   flatbuffers::FlatBufferBuilder builder(kMaxNum1024);
130   auto offset = schema::MetaGraph::Pack(builder, meta_graph);
131   builder.Finish(offset);
132   schema::FinishMetaGraphBuffer(builder, offset);
133   size_t size = builder.GetSize();
134   auto content = reinterpret_cast<const char *>(builder.GetBufferPointer());
135   if (content == nullptr) {
136     MS_LOG(ERROR) << "GetBufferPointer nullptr";
137     delete meta_graph;
138     return nullptr;
139   }
140   delete meta_graph;
141   return std::shared_ptr<lite::Model>(LiteModel::Import(content, size));
142 }
143 
DoQuantDebug(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param,const std::shared_ptr<mindspore::Model> & origin_model,const std::shared_ptr<lite::Model> & origin_lite_model)144 int DoQuantDebug(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param,
145                  const std::shared_ptr<mindspore::Model> &origin_model,
146                  const std::shared_ptr<lite::Model> &origin_lite_model) {
147   auto quant_model = std::make_shared<mindspore::Model>();
148   CHECK_NULL_RETURN(quant_model);
149   size_t size = 0;
150   auto status = BuildModelByFuncGraph(quant_model, old_graph, param, &size);
151   if (status != kSuccess) {
152     MS_LOG(ERROR) << "Build model failed";
153     return RET_ERROR;
154   }
155   std::map<std::string, OpParameter *> op_parameters;
156   auto ret = FetchOpParameterFromFuncGraph(old_graph, &op_parameters);
157   if (ret != RET_OK) {
158     MS_LOG(ERROR) << "Fetch op parameter from funcgraph failed";
159     return ret;
160   }
161   DebugInfoManager manager;
162 
163   auto quant_lite_model = ParseLiteModel(old_graph, param);
164   if (quant_lite_model == nullptr) {
165     MS_LOG(ERROR) << "Parse quant lite model failed";
166     return RET_ERROR;
167   }
168   if (origin_lite_model == nullptr) {
169     MS_LOG(ERROR) << "Origin lite model nullptr.";
170     return RET_ERROR;
171   }
172 
173   ret = manager.CompareOriginWithQuant(origin_model, quant_model, op_parameters, param, origin_lite_model,
174                                        quant_lite_model);
175   auto free_buffer = [&] {
176     for (auto parameter : op_parameters) {
177       if (parameter.second != nullptr) {
178         free(parameter.second);
179         parameter.second = nullptr;
180       }
181     }
182     op_parameters.clear();
183   };
184   if (ret != RET_OK) {
185     MS_LOG(ERROR) << "Compare origin with quant failed.";
186     free_buffer();
187     return ret;
188   }
189   free_buffer();
190   return RET_OK;
191 }
192 
ConvertValueNodeToParameter(const FuncGraphPtr & func_graph)193 int ConvertValueNodeToParameter(const FuncGraphPtr &func_graph) {
194   auto cnodes = func_graph->GetOrderedCnodes();
195   for (auto &cnode : cnodes) {
196     for (size_t i = kPrimOffset; i < cnode->size(); ++i) {
197       auto input = cnode->input(i);
198       if (!input->isa<ValueNode>()) {
199         continue;
200       }
201       auto tensor_info = input->cast<ValueNodePtr>()->value()->cast<tensor::TensorPtr>();
202       if (tensor_info == nullptr) {
203         MS_LOG(INFO) << cnode->fullname_with_scope() << " input index: " << i << " cast tensor nullptr.";
204         continue;
205       }
206       auto parameter = func_graph->add_parameter();
207       auto status = InitParameterFromTensorInfo(parameter, tensor_info);
208       if (status != RET_OK) {
209         MS_LOG(ERROR) << "Init parameter From tensor failed, tenor: " << tensor_info->name();
210         return status;
211       }
212       parameter->set_name(input->fullname_with_scope());
213       auto manage = Manage(func_graph);
214       manage->Replace(input, parameter);
215     }
216   }
217   return RET_OK;
218 }
219 
PrepareQuantize(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)220 int QuantizationOptimizer::PrepareQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
221   if (!param->train_model && param->save_type == kMindIR) {
222     auto status = ConvertValueNodeToParameter(old_graph);
223     if (status != RET_OK) {
224       MS_LOG(ERROR) << "Convert value node To parameter failed.";
225       return status;
226     }
227   }
228 
229   auto convert_pm = std::make_shared<opt::LitePassManager>("anf graph convert pass manager", true);
230   convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
231   auto optimizer = std::make_shared<opt::GraphOptimizer>();
232   optimizer->AddPassManager(convert_pm);
233   if (optimizer->Optimize(old_graph) == nullptr) {
234     MS_LOG(ERROR) << "run graph pass failed";
235     return RET_ERROR;
236   }
237 
238   bool per_layer = param->commonQuantParam.quant_type == quant::QUANT_ALL && !param->fullQuantParam.per_channel &&
239                    param->fullQuantParam.target_device != DSP;
240   if (per_layer) {
241     CLEStrategy cle_strategy(old_graph);
242     auto status = cle_strategy.Run();
243     if (status != RET_OK) {
244       MS_LOG(ERROR) << "do cle_strategy failed!";
245       return status;
246     }
247   }
248 
249   if (param->commonQuantParam.quant_type == quant::QUANT_ALL && param->fullQuantParam.bias_correction) {
250     SplitSharedBias split_shared_bias(old_graph, param);
251     if (split_shared_bias.Run() != RET_OK) {
252       MS_LOG(ERROR) << "split shared bias node failed!";
253       return RET_ERROR;
254     }
255   }
256   return RET_OK;
257 }
258 
// Quantizes one sub-graph end to end:
//   1. PrepareQuantize (value-node conversion, redundant-op removal, CLE,
//      shared-bias split),
//   2. in debug mode, snapshot an fp32 baseline model first,
//   3. run the configured quantizer (full / weight / dynamic),
//   4. fuse QuantDTypeCast nodes (skipped for ASCEND targets),
//   5. in debug mode, compare the quantized result against the fp32 snapshot.
int QuantizationOptimizer::DoSingleGraphQuantize(const FuncGraphPtr &func_graph,
                                                 const std::shared_ptr<ConverterPara> &param) {
  CHECK_NULL_RETURN(param);
  int status = PrepareQuantize(func_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "PrepareQuantize failed.";
    return status;
  }

  std::shared_ptr<mindspore::Model> origin;
  std::shared_ptr<lite::Model> origin_lite_model;
  if (param->commonQuantParam.is_debug) {  // Bak fp32 model for debug
    // Temporarily force QUANT_NONE so the baseline build exports a plain
    // fp32 model; the original quant_type is restored immediately after,
    // including on the build-failure path.
    auto quant_type = param->commonQuantParam.quant_type;
    param->commonQuantParam.quant_type = quant::QUANT_NONE;
    origin = std::make_shared<mindspore::Model>();
    CHECK_NULL_RETURN(origin);
    size_t size = 0;
    auto ret = BuildModelByFuncGraph(origin, func_graph, param, &size);
    param->commonQuantParam.quant_type = quant_type;
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Build model failed";
      return RET_ERROR;
    }
    origin_lite_model = ParseLiteModel(func_graph, param);
    if (origin_lite_model == nullptr) {
      MS_LOG(ERROR) << "Parse lite model failed.";
      return RET_ERROR;
    }
  }
  if (param->commonQuantParam.quant_type == quant::QUANT_ALL) {  // Full Quantization
    // Full quant works on fp32 weights, so fold fp16 constants first.
    status = ConvertFp16ToFp32(func_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Converter fp16 to fp32 failed.";
      return status;
    }
    status = DoFullQuant(func_graph, param);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do full quant failed.";
      return status;
    }
  } else if (param->commonQuantParam.quant_type == quant::QUANT_WEIGHT) {  // Weight Quantization
    status = DoWeightQuant(func_graph, param);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do weight quant failed.";
      return status;
    }
  } else if (param->commonQuantParam.quant_type == quant::QUANT_DYNAMIC) {  // Dynamic Quantization
    status = DoDynamicQuant(func_graph, param);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do dynamic quant failed.";
      return status;
    }
  }

  if (param->fullQuantParam.target_device != ASCEND) {
    // Clean up back-to-back QuantDTypeCast nodes introduced by quantization.
    auto optimizer = std::make_shared<opt::GraphOptimizer>();
    CHECK_NULL_RETURN(optimizer);
    auto fusion_pm = std::make_shared<opt::LitePassManager>("fusion pass manager after quant", false);
    CHECK_NULL_RETURN(fusion_pm);
    fusion_pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
    optimizer->AddPassManager(fusion_pm);
    if (optimizer->Optimize(func_graph) == nullptr) {
      MS_LOG(ERROR) << "run cast node fusion failed.";
      return RET_ERROR;
    }
  }

  if (param->commonQuantParam.is_debug) {
    status = DoQuantDebug(func_graph, param, origin, origin_lite_model);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do quant debug failed.";
      return status;
    }
  }
  return RET_OK;
}
335 
Run(const mindspore::FuncGraphPtr & func_graph)336 int QuantizationOptimizer::Run(const mindspore::FuncGraphPtr &func_graph) {
337   if (param_->commonQuantParam.quant_type == quant::QUANT_NONE || param_->fullQuantParam.target_device == ASCEND) {
338     return RET_OK;
339   }
340   // set manager
341   if (func_graph->manager() == nullptr) {
342     auto root_func_manager = Manage(func_graph);
343     std::set<FuncGraphPtr> all_func_graphs = {};
344     lite::GetAllFuncGraph(func_graph, &all_func_graphs);
345     for (auto graph : all_func_graphs) {
346       graph->set_manager(root_func_manager);
347     }
348   }
349 
350   std::set<FuncGraphPtr> all_func_graphs{};
351   quant::GetFuncGraphs(func_graph, &all_func_graphs);
352   // Support for multi-subgraph models
353   for (auto &item : all_func_graphs) {
354     auto status = DoSingleGraphQuantize(item, param_);
355     if (status != RET_OK) {
356       MS_LOG(ERROR) << "Do Quantize failed.";
357       return status;
358     }
359   }
360   if (param_->fullQuantParam.target_device != ASCEND) {
361     auto optimizer = std::make_shared<opt::GraphOptimizer>();
362     CHECK_NULL_RETURN(optimizer);
363     auto fusion_pm = std::make_shared<opt::LitePassManager>("fusion pass manager after quant", false);
364     CHECK_NULL_RETURN(fusion_pm);
365     fusion_pm->AddPass(std::make_shared<opt::InferShapePass>(param_->fmk_type, param_->train_model));
366     optimizer->AddPassManager(fusion_pm);
367     if (optimizer->Optimize(func_graph) == nullptr) {
368       MS_LOG(ERROR) << "run infershape failed.";
369       return RET_ERROR;
370     }
371   }
372   return RET_OK;
373 }
374 }  // namespace mindspore::lite::quant
375