• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "tools/converter/anf_transform.h"
#include <deque>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "nnacl/op_base.h"
#include "src/common/log_adapter.h"
#include "tools/converter/optimizer_manager.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ir/primitive.h"
#include "tools/optimizer/fusion/affine_activation_fusion.h"
#include "tools/optimizer/fusion/affine_fusion.h"
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include "tools/optimizer/fusion/conv_activation_fusion.h"
#include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include "tools/optimizer/fusion/conv_pad_fusion.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/glu_fusion.h"
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
#include "tools/optimizer/fusion/squeeze_fusion.h"
#include "tools/optimizer/fusion/reshape_reshape_fusion.h"
#include "tools/optimizer/graph/add_tensor_array.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include "tools/optimizer/graph/control_flow_pass.h"
#include "tools/optimizer/graph/reduce_same_act_pass.h"
#include "tools/optimizer/graph/split_one_pass.h"
#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "tools/optimizer/graph/dump_graph.h"
#include "tools/converter/quantizer/full_quant_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "tools/optimizer/parallel/spliter.h"
#include "tools/optimizer/fisson/iter_node_outputs.h"
#include "tools/optimizer/fisson/node_out_shapes.h"
#include "tools/optimizer/parallel/parallel_pass.h"
#include "include/registry/pass_registry.h"
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
#include "tools/optimizer/fusion/transpose_fusion.h"
#include "tools/optimizer/format/to_nchw_format.h"
#include "tools/optimizer/format/to_nhwc_format.h"
#include "tools/converter/acl/acl_pass.h"
79 
80 using std::string;
81 namespace mindspore::lite {
82 AnfTransform::AnfTransform() = default;
83 
84 AnfTransform::~AnfTransform() = default;
85 
MarkTrainInputOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)86 STATUS AnfTransform::MarkTrainInputOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
87   for (size_t i = 1; i < cnode->inputs().size(); i++) {
88     auto input_node = cnode->input(i);
89     if (!utils::isa<CNodePtr>(input_node)) {
90       continue;
91     }
92     auto input_cnode = utils::cast<CNodePtr>(input_node);
93     MS_CHECK_TRUE_RET(input_cnode != nullptr, RET_ERROR);
94     auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
95     if (prim == nullptr) {
96       MS_LOG(DEBUG) << "Primitive is nullptr.";
97       continue;
98     }
99     prim->AddAttr("trainOp", MakeValue(true));
100   }
101   return RET_OK;
102 }
103 
MarkTrainWeightSharingOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)104 STATUS AnfTransform::MarkTrainWeightSharingOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
105   auto node_list = TopoSort(func_graph->get_return());
106   for (auto &node : node_list) {
107     if (!utils::isa<CNodePtr>(node)) {
108       continue;
109     }
110     auto graph_cnode = utils::cast<CNodePtr>(node);
111     MS_CHECK_TRUE_RET(graph_cnode != nullptr, RET_ERROR);
112     auto graph_prim = GetValueNode<PrimitivePtr>(graph_cnode->input(0));
113     if (graph_prim == nullptr) {
114       MS_LOG(DEBUG) << "Primitive is nullptr.";
115       continue;
116     }
117     for (size_t i = 1; i < graph_cnode->inputs().size(); i++) {
118       for (size_t j = 1; j < cnode->inputs().size(); j++) {
119         if ((graph_cnode->input(i) == cnode->input(j)) && utils::isa<Parameter>(cnode->input(j))) {
120           graph_prim->AddAttr("trainOp", MakeValue(true));
121         }
122       }
123     }
124   }
125   return RET_OK;
126 }
127 
MarkTrainOp(const FuncGraphPtr & func_graph)128 STATUS AnfTransform::MarkTrainOp(const FuncGraphPtr &func_graph) {
129   auto node_list = TopoSort(func_graph->get_return());
130   for (auto &node : node_list) {
131     if (!utils::isa<CNodePtr>(node)) {
132       continue;
133     }
134     auto cnode = utils::cast<CNodePtr>(node);
135     MS_CHECK_TRUE_RET(cnode != nullptr, RET_ERROR);
136     auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
137     if (prim == nullptr) {
138       MS_LOG(DEBUG) << "Primitive is nullptr.";
139       continue;
140     }
141     if (opt::IsTrainOp(cnode)) {
142       prim->AddAttr("trainOp", MakeValue(true));
143       auto status = MarkTrainInputOp(func_graph, cnode);
144       if (status != RET_OK) {
145         MS_LOG(ERROR) << "MarkTrainInputOp failed.";
146         return RET_ERROR;
147       }
148       status = MarkTrainWeightSharingOp(func_graph, cnode);
149       if (status != RET_OK) {
150         MS_LOG(ERROR) << "MarkTrainWeightSharingOp failed.";
151         return RET_ERROR;
152       }
153     }
154   }
155   return RET_OK;
156 }
157 
RunFusionPass(const FuncGraphPtr & old_graph,const converter::Flags * config)158 int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
159   auto status = MarkTrainOp(old_graph);
160   if (status != RET_OK) {
161     MS_LOG(ERROR) << "MarkTrainOp failed.";
162     return RET_ERROR;
163   }
164   CHECK_NULL_RETURN(config);
165   auto optimizer = std::make_shared<opt::GraphOptimizer>();
166   CHECK_NULL_RETURN(optimizer);
167   auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
168   CHECK_NULL_RETURN(fusion_pm);
169 
170   // The training model only does the fusion of the inference part
171   // remove quantdtype when awaretraining
172   fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
173   fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
174   fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
175   fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
176   fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>(config->fmk));
177   fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>(config->fmk));
178   fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
179   fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());
180   fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
181   fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
182   fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
183   fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
184   fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
185   fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
186   fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
187   fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
188   fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
189   fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
190   fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
191   fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
192   fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
193   fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
194   fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
195   if (config->fmk == converter::kFmkTypeMs && !config->trainModel) {
196     auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
197     if (remove_unused_cast_pass == nullptr) {
198       MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
199       return RET_ERROR;
200     }
201     remove_unused_cast_pass->SetFmkType(config->fmk);
202     fusion_pm->AddPass(remove_unused_cast_pass);
203   }
204   fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
205   fusion_pm->AddPass(std::make_shared<opt::ConvPadFusion>());
206   fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
207   optimizer->AddPassManager(fusion_pm);
208   if (optimizer->Optimize(old_graph) == nullptr) {
209     MS_LOG(ERROR) << "run op fusion failed.";
210     return RET_ERROR;
211   }
212   return RET_OK;
213 }
214 
RunParallelPass(const FuncGraphPtr & old_graph,const converter::Flags * config)215 int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
216   CHECK_NULL_RETURN(old_graph);
217   CHECK_NULL_RETURN(config);
218   MS_LOG(DEBUG) << "Run ParallelPass start";
219   if (config->trainModel || config->parallel_split_config_.parallel_split_type_ == converter::SplitNo) {
220     return RET_OK;
221   }
222   if (config->parallel_split_config_.parallel_split_type_ == converter::SplitByUserRatio) {
223     auto optimizer = std::make_shared<opt::GraphOptimizer>();
224     CHECK_NULL_RETURN(optimizer);
225     auto graph_inputs = old_graph->get_inputs();
226     opt::SplitMode split_mode = opt::NoSplit;
227     for (const auto &graph_input : graph_inputs) {
228       if (utils::isa<Parameter>(graph_input)) {
229         auto input_parameter = dyn_cast<Parameter>(graph_input);
230         MSLITE_CHECK_PTR(input_parameter->Shape());
231         auto shape_ptr = input_parameter->Shape()->cast<abstract::ShapePtr>();
232         MSLITE_CHECK_PTR(shape_ptr);
233         auto batch = shape_ptr->shape().front();
234         if (batch > opt::kDefaultBatch) {
235           split_mode = opt::SplitN;
236         } else {
237           split_mode = opt::SplitH;
238         }
239         break;
240       }
241     }
242     // 1. deal with split strategy
243     std::unordered_map<std::string, opt::SplitStrategy> split_strategys =
244       opt::ParserSplitStrategy(config->parallel_split_config_.parallel_compute_rates_,
245                                config->parallel_split_config_.parallel_devices_, split_mode);
246     if (split_strategys.empty()) {
247       MS_LOG(ERROR) << "parse split_strategy error.";
248       return RET_OK;
249     }
250     opt::Spliter::GetInstance()->RecordGraphInfo(old_graph);
251     auto parallel_pm = std::make_shared<opt::PassManager>("anf parallel pass manager", true);
252     CHECK_NULL_RETURN(parallel_pm);
253     // 2. preceding parallel pass
254     parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
255     parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
256     std::set<int, opt::IntCompare> match_multi_numbers = opt::Spliter::GetInstance()->graph_match_multi_numbers();
257     int max_match_number = *match_multi_numbers.begin();
258     // we do not deal with single conv node
259     for (int match_number = max_match_number; match_number > opt::kDefaultBatch; --match_number) {
260       // 3. multi_conv parallel pass
261       parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, config->fmk, match_number));
262       parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
263       parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
264     }
265     optimizer->AddPassManager(parallel_pm);
266     if (optimizer->Optimize(old_graph) == nullptr) {
267       MS_LOG(ERROR) << "run const fold failed.";
268       return RET_ERROR;
269     }
270   }
271   MS_LOG(DEBUG) << "Run ParallelPass end";
272   return RET_OK;
273 }
274 
RunGraphPass(const FuncGraphPtr & old_graph,const converter::Flags * config)275 int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
276   CHECK_NULL_RETURN(old_graph);
277   CHECK_NULL_RETURN(config);
278   auto optimizer = std::make_shared<opt::GraphOptimizer>();
279   CHECK_NULL_RETURN(optimizer);
280   auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
281   CHECK_NULL_RETURN(graph_pm);
282   if (config->fmk == converter::kFmkTypeTflite || config->fmk == converter::kFmkTypeTf ||
283       config->fmk == converter::kFmkTypeOnnx) {
284     graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
285   }
286   auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
287   CHECK_NULL_RETURN(slice_prepose_pass);
288   slice_prepose_pass->SetFmkType(config->fmk);
289   graph_pm->AddPass(slice_prepose_pass);
290   graph_pm->AddPass(std::make_shared<opt::AddTensorArray>());
291   optimizer->AddPassManager(graph_pm);
292   if (optimizer->Optimize(old_graph) == nullptr) {
293     MS_LOG(ERROR) << "run  graph pass failed.";
294     return RET_ERROR;
295   }
296   return RET_OK;
297 }
298 
RunConvertPass(const FuncGraphPtr & old_graph,const converter::Flags * config)299 int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
300 #ifdef ENABLE_LITE_ACL
301   auto acl_pass = std::make_shared<opt::AclPass>(config->fmk);
302   if (!acl_pass->Run(old_graph)) {
303     MS_LOG(ERROR) << "Acl pass failed.";
304     return RET_ERROR;
305   }
306 #endif
307   auto optimizer = std::make_shared<opt::GraphOptimizer>();
308   CHECK_NULL_RETURN(optimizer);
309   auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
310   CHECK_NULL_RETURN(convert_pm);
311   convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel));
312   auto infershape_pass = std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel);
313   CHECK_NULL_RETURN(infershape_pass);
314   convert_pm->AddPass(infershape_pass);
315   auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
316   convert_pm->AddPass(update_conv2d_param_pass);
317   optimizer->AddPassManager(convert_pm);
318   if (optimizer->Optimize(old_graph) == nullptr) {
319     MS_LOG(ERROR) << "run graph convert pass failed.";
320     return RET_ERROR;
321   }
322   return RET_OK;
323 }
324 
RunConstFoldPass(const FuncGraphPtr & old_graph,const converter::Flags * config)325 int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
326   CHECK_NULL_RETURN(config);
327   auto optimizer = std::make_shared<opt::GraphOptimizer>();
328   auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
329   CHECK_NULL_RETURN(optimizer);
330   CHECK_NULL_RETURN(const_fold_pm);
331   if (!config->trainModel) {
332     const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
333   }
334   const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel));
335   const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
336   const_fold_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
337   optimizer->AddPassManager(const_fold_pm);
338   if (optimizer->Optimize(old_graph) == nullptr) {
339     MS_LOG(ERROR) << "run const fold failed.";
340     return RET_ERROR;
341   }
342   return RET_OK;
343 }
344 
GetFuncGraphs(const FuncGraphPtr & func_graph,std::set<FuncGraphPtr> * all_func_graphs)345 void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
346   MS_ASSERT(func_graph != nullptr);
347   MS_ASSERT(all_func_graphs != nullptr);
348   all_func_graphs->insert(func_graph);
349   auto nodes = func_graph->GetOrderedCnodes();
350   std::deque<CNodePtr> to_process{};
351   to_process.insert(to_process.end(), nodes.begin(), nodes.end());
352   while (!to_process.empty()) {
353     auto &cur_cnode = to_process.front();
354     for (auto &input : cur_cnode->inputs()) {
355       if (!IsValueNode<FuncGraph>(input)) {
356         continue;
357       }
358       auto new_fg = GetValueNode<FuncGraphPtr>(input);
359       if (all_func_graphs->find(new_fg) != all_func_graphs->end()) {
360         continue;
361       }
362       all_func_graphs->insert(new_fg);
363       auto new_nodes = new_fg->GetOrderedCnodes();
364       to_process.insert(to_process.end(), new_nodes.begin(), new_nodes.end());
365     }
366     to_process.pop_front();
367   }
368 }
369 
DoSingleGraphQuantize(const FuncGraphPtr & old_graph,const converter::Flags * config)370 int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
371   // quant
372   if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
373     this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
374     if (m_quantizer_ == nullptr) {
375       MS_LOG(ERROR) << "New FullQuantQuantizer failed";
376       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
377       return RET_ERROR;
378     }
379   } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
380     this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(old_graph, *config);
381     if (m_quantizer_ == nullptr) {
382       MS_LOG(ERROR) << "New WeightQuantizer failed";
383       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
384       return RET_ERROR;
385     }
386   }
387   if (m_quantizer_ != nullptr) {
388     m_quantizer_->flags = *config;
389     auto status = m_quantizer_->DoQuantize(old_graph);
390     if (status != RET_OK) {
391       MS_LOG(ERROR) << "DoQuantization failed " << status;
392       ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
393       return RET_ERROR;
394     }
395   }
396   return RET_OK;
397 }
398 
DoQuantize(const FuncGraphPtr & old_graph,const converter::Flags * config)399 int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
400   std::set<FuncGraphPtr> all_func_graphs{};
401   GetFuncGraphs(old_graph, &all_func_graphs);
402   for (auto &item : all_func_graphs) {
403     auto status = DoSingleGraphQuantize(item, config);
404     if (status != RET_OK) {
405       MS_LOG(ERROR) << "Do Quantize failed.";
406       return status;
407     }
408   }
409   return RET_OK;
410 }
411 
TransformFuncGraph(const FuncGraphPtr & old_graph,const converter::Flags * config)412 FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
413   MS_ASSERT(old_graph != nullptr);
414   if (config == nullptr) {
415     MS_LOG(ERROR) << "config should be specified";
416     return nullptr;
417   }
418 
419   auto status = RunConvertPass(old_graph, config);
420   if (status != RET_OK) {
421     MS_LOG(ERROR) << "Run convert pass failed.";
422     return nullptr;
423   }
424 
425   if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
426     MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
427     return nullptr;
428   }
429 
430   status = RunConstFoldPass(old_graph, config);
431   if (status != RET_OK) {
432     MS_LOG(ERROR) << "Run const fold pass failed.";
433     return nullptr;
434   }
435 
436   if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
437     MS_LOG(WARNING) << "Run infershape opt pass failed.";
438   } else {
439     if (!RunOptimizerPass(old_graph, {"DecreaseTransposeAlgo"})) {
440       MS_LOG(ERROR) << "Run transpose opt pass failed.";
441       return nullptr;
442     }
443   }
444 
445   auto reduce_act_pass = std::make_shared<opt::ReduceSameActPass>();
446   MS_CHECK_TRUE_RET(reduce_act_pass != nullptr, nullptr);
447   if (!reduce_act_pass->Run(old_graph)) {
448     MS_LOG(ERROR) << "Run reduce same act pass failed.";
449     return nullptr;
450   }
451 
452   auto split_one_pass = std::make_shared<opt::SplitOnePass>();
453   MS_CHECK_TRUE_RET(split_one_pass != nullptr, nullptr);
454   if (!split_one_pass->Run(old_graph)) {
455     MS_LOG(ERROR) << "Run split one pass failed.";
456     return nullptr;
457   }
458 
459   if (!config->disableFusion) {
460     status = RunFusionPass(old_graph, config);
461     if (status != RET_OK) {
462       MS_LOG(ERROR) << "Run fusion pass failed.";
463       return nullptr;
464     }
465   }
466 
467   if (!RunExternalPass(old_graph, registry::POSITION_END)) {
468     MS_LOG(ERROR) << "Run external pass failed, place is END";
469     return nullptr;
470   }
471 
472   if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
473     MS_LOG(WARNING) << "Run infershape opt pass failed.";
474     if (!RunOptimizerPass(old_graph, {"SpecifyGraphInputFormat"})) {
475       MS_LOG(ERROR) << "specify the input format of exported model failed.";
476       return nullptr;
477     }
478   } else {
479     if (!RunOptimizerPass(old_graph, {"SpecifyGraphInputFormat", "DecreaseTransposeAlgo"})) {
480       MS_LOG(ERROR) << "Run transpose opt pass failed.";
481       return nullptr;
482     }
483   }
484 
485   status = RunGraphPass(old_graph, config);
486   if (status != RET_OK) {
487     MS_LOG(ERROR) << "Run convert pass failed.";
488     return nullptr;
489   }
490 
491   status = RunParallelPass(old_graph, config);
492   if (status != RET_OK) {
493     MS_LOG(ERROR) << "Run convert pass failed.";
494     return nullptr;
495   }
496 
497   status = DoQuantize(old_graph, config);
498   if (status != RET_OK) {
499     MS_LOG(ERROR) << "Do Quantize failed.";
500     return nullptr;
501   }
502 
503   return old_graph;
504 }
505 
StoreBuiltinPass(const converter::Flags * config)506 bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
507   if (config == nullptr) {
508     MS_LOG(ERROR) << "config is nullptr";
509     return false;
510   }
511   auto fmk = config->fmk;
512   auto is_train = config->trainModel;
513   std::unordered_map<std::string, opt::PassPtr> passes = {
514     {"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
515     {"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
516     {"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
517     {"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
518     {"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)},
519     {"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat)}};
520   bool succeed_store = true;
521   for (auto iter = passes.begin(); iter != passes.end(); ++iter) {
522     if (PassStorage::StorePass(iter->first, iter->second) != RET_OK) {
523       MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is " << iter->first
524                     << ", please edit external pass name.";
525       succeed_store = false;
526     }
527   }
528   return succeed_store;
529 }
530 
/**
 * Public entry point: registers the built-in passes, then transforms `main_graph`.
 *
 * @return the transformed graph, or nullptr on failure. A failed transformation also records
 *         RET_NULL_PTR in the global ReturnCode; a failed pass registration does not.
 */
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
  if (StoreBuiltinPass(config)) {
    auto transformed_graph = TransformFuncGraph(main_graph, config);
    if (transformed_graph == nullptr) {
      MS_LOG(ERROR) << "optimizer failed.";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
    }
    return transformed_graph;
  }
  MS_LOG(ERROR) << "store pass failed.";
  return nullptr;
}
543 }  // namespace mindspore::lite
544