/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tools/converter/anf_transform.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <deque>
#include <set>
#include "nnacl/op_base.h"
#include "src/common/log_adapter.h"
#include "tools/converter/optimizer_manager.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "ir/primitive.h"
#include "tools/optimizer/fusion/affine_activation_fusion.h"
#include "tools/optimizer/fusion/affine_fusion.h"
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include "tools/optimizer/fusion/conv_activation_fusion.h"
#include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/norm_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include "tools/optimizer/fusion/conv_pad_fusion.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/glu_fusion.h"
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
#include "tools/optimizer/fusion/squeeze_fusion.h"
#include "tools/optimizer/fusion/reshape_reshape_fusion.h"
#include "tools/optimizer/graph/add_tensor_array.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include "tools/optimizer/graph/unused_cast_node_remove_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include "tools/optimizer/graph/control_flow_pass.h"
#include "tools/optimizer/graph/reduce_same_act_pass.h"
#include "tools/optimizer/graph/split_one_pass.h"
#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "tools/optimizer/graph/dump_graph.h"
#include "tools/converter/quantizer/full_quant_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "tools/optimizer/parallel/spliter.h"
#include "tools/optimizer/fisson/iter_node_outputs.h"
#include "tools/optimizer/fisson/node_out_shapes.h"
#include "tools/optimizer/parallel/parallel_pass.h"
#include "include/registry/pass_registry.h"
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
#include "tools/optimizer/fusion/transpose_fusion.h"
#include "tools/optimizer/format/to_nchw_format.h"
#include "tools/optimizer/format/to_nhwc_format.h"
#include "tools/converter/acl/acl_pass.h"

using std::string;
namespace mindspore::lite {
AnfTransform::AnfTransform() = default;

AnfTransform::~AnfTransform() = default;

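// Adds the "trainOp" attribute to the primitive of every CNode that feeds `cnode`,
// so that later passes (e.g. fusion) can recognize operators belonging to the
// training portion of the graph.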
STATUS AnfTransform::MarkTrainInputOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  for (size_t i = 1; i < cnode->inputs().size(); i++) {
    auto input_node = cnode->input(i);
    if (!utils::isa<CNodePtr>(input_node)) {
      continue;
    }
    auto input_cnode = utils::cast<CNodePtr>(input_node);
    MS_CHECK_TRUE_RET(input_cnode != nullptr, RET_ERROR);
    auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
    if (prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    prim->AddAttr("trainOp", MakeValue(true));
  }
  return RET_OK;
}

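// Adds the "trainOp" attribute to any CNode in `func_graph` that shares a Parameter
// (weight) input with the given training `cnode`.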
STATUS AnfTransform::MarkTrainWeightSharingOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto graph_cnode = utils::cast<CNodePtr>(node);
    MS_CHECK_TRUE_RET(graph_cnode != nullptr, RET_ERROR);
    auto graph_prim = GetValueNode<PrimitivePtr>(graph_cnode->input(0));
    if (graph_prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    for (size_t i = 1; i < graph_cnode->inputs().size(); i++) {
      for (size_t j = 1; j < cnode->inputs().size(); j++) {
        if ((graph_cnode->input(i) == cnode->input(j)) && utils::isa<Parameter>(cnode->input(j))) {
          graph_prim->AddAttr("trainOp", MakeValue(true));
        }
      }
    }
  }
  return RET_OK;
}

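// Traverses the graph in topological order and marks each training operator, its CNode
// inputs, and any weight-sharing operators with the "trainOp" attribute.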
STATUS AnfTransform::MarkTrainOp(const FuncGraphPtr &func_graph) {
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = utils::cast<CNodePtr>(node);
    MS_CHECK_TRUE_RET(cnode != nullptr, RET_ERROR);
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    if (opt::IsTrainOp(cnode)) {
      prim->AddAttr("trainOp", MakeValue(true));
      auto status = MarkTrainInputOp(func_graph, cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "MarkTrainInputOp failed.";
        return RET_ERROR;
      }
      status = MarkTrainWeightSharingOp(func_graph, cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "MarkTrainWeightSharingOp failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}

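// Builds the fusion pass manager and runs all operator fusions over `old_graph`;
// passes execute in the order they are registered below.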
int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  auto status = MarkTrainOp(old_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "MarkTrainOp failed.";
    return RET_ERROR;
  }
  CHECK_NULL_RETURN(config);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  CHECK_NULL_RETURN(optimizer);
  auto fusion_pm = std::make_shared<opt::PassManager>("anf fusion pass manager", false);
  CHECK_NULL_RETURN(fusion_pm);

  // For training models, only the inference part of the graph is fused.
  // The quant dtype is removed when quantization-aware training is used.
  fusion_pm->AddPass(std::make_shared<opt::SqueezeFusion>());
  fusion_pm->AddPass(std::make_shared<opt::TransposeFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ReshapeReshapeFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>(config->fmk));
  fusion_pm->AddPass(std::make_shared<opt::ConvScaleFusion>(config->fmk));
  fusion_pm->AddPass(std::make_shared<opt::TfNormFusion>());
  fusion_pm->AddPass(std::make_shared<opt::OnnxLayerNormFusion>());
  fusion_pm->AddPass(std::make_shared<opt::BatchMatMulFusion>());
  fusion_pm->AddPass(std::make_shared<opt::SigmoidMulFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ConvActivationFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ConvTupleGetItemFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ConvTupleActivationFusion>());
  fusion_pm->AddPass(std::make_shared<opt::TfliteLstmCellFusion>());
  fusion_pm->AddPass(std::make_shared<opt::TfLstmCellFusion>());
  fusion_pm->AddPass(std::make_shared<opt::TfBidirectionGruFusion>());
  fusion_pm->AddPass(std::make_shared<opt::TfGeLUFusion>());
  fusion_pm->AddPass(std::make_shared<opt::OnnxGeLUFusion>());
  fusion_pm->AddPass(std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>());
  fusion_pm->AddPass(std::make_shared<opt::GLUFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
  fusion_pm->AddPass(std::make_shared<opt::AffineFusion>());
  fusion_pm->AddPass(std::make_shared<opt::AffineActivationFusion>());
  if (config->fmk == converter::kFmkTypeMs && !config->trainModel) {
    auto remove_unused_cast_pass = std::make_shared<opt::RemoveUnusedCastOpPass>();
    if (remove_unused_cast_pass == nullptr) {
      MS_LOG(ERROR) << "RemoveUnusedCastOpPass should be specified";
      return RET_ERROR;
    }
    remove_unused_cast_pass->SetFmkType(config->fmk);
    fusion_pm->AddPass(remove_unused_cast_pass);
  }
  fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
  fusion_pm->AddPass(std::make_shared<opt::ConvPadFusion>());
  fusion_pm->AddPass(std::make_shared<opt::MatMulAddFusion>());
  optimizer->AddPassManager(fusion_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run op fusion failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

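// Splits multi-convolution subgraphs across devices according to the user-provided
// compute ratios; skipped for training models or when parallel splitting is disabled.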
int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  CHECK_NULL_RETURN(old_graph);
  CHECK_NULL_RETURN(config);
  MS_LOG(DEBUG) << "Run ParallelPass start";
  if (config->trainModel || config->parallel_split_config_.parallel_split_type_ == converter::SplitNo) {
    return RET_OK;
  }
  if (config->parallel_split_config_.parallel_split_type_ == converter::SplitByUserRatio) {
    auto optimizer = std::make_shared<opt::GraphOptimizer>();
    CHECK_NULL_RETURN(optimizer);
    auto graph_inputs = old_graph->get_inputs();
    opt::SplitMode split_mode = opt::NoSplit;
    for (const auto &graph_input : graph_inputs) {
      if (utils::isa<Parameter>(graph_input)) {
        auto input_parameter = dyn_cast<Parameter>(graph_input);
        MSLITE_CHECK_PTR(input_parameter->Shape());
        auto shape_ptr = input_parameter->Shape()->cast<abstract::ShapePtr>();
        MSLITE_CHECK_PTR(shape_ptr);
        auto batch = shape_ptr->shape().front();
        if (batch > opt::kDefaultBatch) {
          split_mode = opt::SplitN;
        } else {
          split_mode = opt::SplitH;
        }
        break;
      }
    }
    // 1. deal with split strategy
    std::unordered_map<std::string, opt::SplitStrategy> split_strategys =
      opt::ParserSplitStrategy(config->parallel_split_config_.parallel_compute_rates_,
                               config->parallel_split_config_.parallel_devices_, split_mode);
    if (split_strategys.empty()) {
      MS_LOG(ERROR) << "parse split_strategy error.";
      return RET_OK;
    }
    opt::Spliter::GetInstance()->RecordGraphInfo(old_graph);
    auto parallel_pm = std::make_shared<opt::PassManager>("anf parallel pass manager", true);
    CHECK_NULL_RETURN(parallel_pm);
    // 2. preceding parallel pass
    parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
    parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
    std::set<int, opt::IntCompare> match_multi_numbers = opt::Spliter::GetInstance()->graph_match_multi_numbers();
    int max_match_number = *match_multi_numbers.begin();
    // we do not deal with single conv nodes
    for (int match_number = max_match_number; match_number > opt::kDefaultBatch; --match_number) {
      // 3. multi_conv parallel pass
      parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, config->fmk, match_number));
      parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
      parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
    }
    optimizer->AddPassManager(parallel_pm);
    if (optimizer->Optimize(old_graph) == nullptr) {
      MS_LOG(ERROR) << "run parallel pass failed.";
      return RET_ERROR;
    }
  }
  MS_LOG(DEBUG) << "Run ParallelPass end";
  return RET_OK;
}

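// Graph-level passes: control-flow lowering for TF/TFLite/ONNX models, slice preposing,
// and TensorArray handling.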
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  CHECK_NULL_RETURN(old_graph);
  CHECK_NULL_RETURN(config);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  CHECK_NULL_RETURN(optimizer);
  auto graph_pm = std::make_shared<opt::PassManager>("anf graph pass manager", true);
  CHECK_NULL_RETURN(graph_pm);
  if (config->fmk == converter::kFmkTypeTflite || config->fmk == converter::kFmkTypeTf ||
      config->fmk == converter::kFmkTypeOnnx) {
    graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
  }
  auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
  CHECK_NULL_RETURN(slice_prepose_pass);
  slice_prepose_pass->SetFmkType(config->fmk);
  graph_pm->AddPass(slice_prepose_pass);
  graph_pm->AddPass(std::make_shared<opt::AddTensorArray>());
  optimizer->AddPassManager(graph_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run graph pass failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

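// Normalizes the freshly imported graph: the ACL pass (when ENABLE_LITE_ACL is defined),
// redundant-op removal, shape inference, and Conv2D parameter updates.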
int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
#ifdef ENABLE_LITE_ACL
  auto acl_pass = std::make_shared<opt::AclPass>(config->fmk);
  if (!acl_pass->Run(old_graph)) {
    MS_LOG(ERROR) << "Acl pass failed.";
    return RET_ERROR;
  }
#endif
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  CHECK_NULL_RETURN(optimizer);
  auto convert_pm = std::make_shared<opt::PassManager>("anf graph convert pass manager", true);
  CHECK_NULL_RETURN(convert_pm);
  convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(config->trainModel));
  auto infershape_pass = std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel);
  CHECK_NULL_RETURN(infershape_pass);
  convert_pm->AddPass(infershape_pass);
  auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
  CHECK_NULL_RETURN(update_conv2d_param_pass);
  convert_pm->AddPass(update_conv2d_param_pass);
  optimizer->AddPassManager(convert_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run graph convert pass failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

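// Folds constant subgraphs (inference models only), then re-infers shapes, updates
// Conv2D parameters, and converts Clip ops into activations.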
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  CHECK_NULL_RETURN(config);
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto const_fold_pm = std::make_shared<opt::PassManager>("const fold fusion pass manager", false);
  CHECK_NULL_RETURN(optimizer);
  CHECK_NULL_RETURN(const_fold_pm);
  if (!config->trainModel) {
    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(config->fmk));
  }
  const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(config->fmk, config->trainModel));
  const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
  const_fold_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
  optimizer->AddPassManager(const_fold_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run const fold failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

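// Collects `func_graph` and every sub-graph reachable through ValueNode<FuncGraph>
// inputs into `all_func_graphs`, using a breadth-first traversal over the CNodes.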
void AnfTransform::GetFuncGraphs(const FuncGraphPtr &func_graph, std::set<FuncGraphPtr> *all_func_graphs) {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(all_func_graphs != nullptr);
  all_func_graphs->insert(func_graph);
  auto nodes = func_graph->GetOrderedCnodes();
  std::deque<CNodePtr> to_process{};
  to_process.insert(to_process.end(), nodes.begin(), nodes.end());
  while (!to_process.empty()) {
    auto &cur_cnode = to_process.front();
    for (auto &input : cur_cnode->inputs()) {
      if (!IsValueNode<FuncGraph>(input)) {
        continue;
      }
      auto new_fg = GetValueNode<FuncGraphPtr>(input);
      if (all_func_graphs->find(new_fg) != all_func_graphs->end()) {
        continue;
      }
      all_func_graphs->insert(new_fg);
      auto new_nodes = new_fg->GetOrderedCnodes();
      to_process.insert(to_process.end(), new_nodes.begin(), new_nodes.end());
    }
    to_process.pop_front();
  }
}

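// Quantizes a single graph: full quantization for QUANT_ALL, weight-only quantization
// for QUANT_WEIGHT; any other quant type leaves the graph untouched.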
int AnfTransform::DoSingleGraphQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  // quant
  if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_ALL) {
    this->m_quantizer_ = std::make_unique<quant::FullQuantQuantizer>(old_graph, config->commonQuantParam.bit_num);
    if (m_quantizer_ == nullptr) {
      MS_LOG(ERROR) << "New FullQuantQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return RET_ERROR;
    }
  } else if (config->commonQuantParam.quant_type == schema::QuantType_QUANT_WEIGHT) {
    this->m_quantizer_ = std::make_unique<quant::WeightQuantizer>(old_graph, *config);
    if (m_quantizer_ == nullptr) {
      MS_LOG(ERROR) << "New WeightQuantizer failed";
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_MEMORY_FAILED);
      return RET_ERROR;
    }
  }
  if (m_quantizer_ != nullptr) {
    m_quantizer_->flags = *config;
    auto status = m_quantizer_->DoQuantize(old_graph);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "DoQuantization failed " << status;
      ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
      return RET_ERROR;
    }
  }
  return RET_OK;
}

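// Quantizes the main graph and all of its sub-graphs.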
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  std::set<FuncGraphPtr> all_func_graphs{};
  GetFuncGraphs(old_graph, &all_func_graphs);
  for (auto &item : all_func_graphs) {
    auto status = DoSingleGraphQuantize(item, config);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do Quantize failed.";
      return status;
    }
  }
  return RET_OK;
}

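// Runs the whole optimization pipeline over `old_graph`: convert passes, external passes
// (BEGIN), constant folding, shape/transpose optimization, fusion, external passes (END),
// graph and parallel passes, and finally quantization. Returns the optimized graph, or
// nullptr on failure.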
FuncGraphPtr AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const converter::Flags *config) {
  MS_ASSERT(old_graph != nullptr);
  if (config == nullptr) {
    MS_LOG(ERROR) << "config should be specified";
    return nullptr;
  }

  auto status = RunConvertPass(old_graph, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run convert pass failed.";
    return nullptr;
  }

  if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
    MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
    return nullptr;
  }

  status = RunConstFoldPass(old_graph, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run const fold pass failed.";
    return nullptr;
  }

  if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
    MS_LOG(WARNING) << "Run infershape opt pass failed.";
  } else {
    if (!RunOptimizerPass(old_graph, {"DecreaseTransposeAlgo"})) {
      MS_LOG(ERROR) << "Run transpose opt pass failed.";
      return nullptr;
    }
  }

  auto reduce_act_pass = std::make_shared<opt::ReduceSameActPass>();
  MS_CHECK_TRUE_RET(reduce_act_pass != nullptr, nullptr);
  if (!reduce_act_pass->Run(old_graph)) {
    MS_LOG(ERROR) << "Run reduce same act pass failed.";
    return nullptr;
  }

  auto split_one_pass = std::make_shared<opt::SplitOnePass>();
  MS_CHECK_TRUE_RET(split_one_pass != nullptr, nullptr);
  if (!split_one_pass->Run(old_graph)) {
    MS_LOG(ERROR) << "Run split one pass failed.";
    return nullptr;
  }

  if (!config->disableFusion) {
    status = RunFusionPass(old_graph, config);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run fusion pass failed.";
      return nullptr;
    }
  }

  if (!RunExternalPass(old_graph, registry::POSITION_END)) {
    MS_LOG(ERROR) << "Run external pass failed, place is END";
    return nullptr;
  }

  if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
    MS_LOG(WARNING) << "Run infershape opt pass failed.";
    if (!RunOptimizerPass(old_graph, {"SpecifyGraphInputFormat"})) {
      MS_LOG(ERROR) << "specify the input format of exported model failed.";
      return nullptr;
    }
  } else {
    if (!RunOptimizerPass(old_graph, {"SpecifyGraphInputFormat", "DecreaseTransposeAlgo"})) {
      MS_LOG(ERROR) << "Run transpose opt pass failed.";
      return nullptr;
    }
  }

  status = RunGraphPass(old_graph, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run graph pass failed.";
    return nullptr;
  }

  status = RunParallelPass(old_graph, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run parallel pass failed.";
    return nullptr;
  }

  status = DoQuantize(old_graph, config);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Do Quantize failed.";
    return nullptr;
  }

  return old_graph;
}

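// Registers the built-in named passes with PassStorage so they can be scheduled by name,
// alongside externally registered passes.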
bool AnfTransform::StoreBuiltinPass(const converter::Flags *config) {
  if (config == nullptr) {
    MS_LOG(ERROR) << "config is nullptr";
    return false;
  }
  auto fmk = config->fmk;
  auto is_train = config->trainModel;
  std::unordered_map<std::string, opt::PassPtr> passes = {
    {"DumpGraph", std::make_shared<opt::DumpGraph>(config)},
    {"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train)},
    {"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train)},
    {"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train)},
    {"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train)},
    {"SpecifyGraphInputFormat", std::make_shared<opt::SpecifyGraphInputFormat>(config->graphInputFormat)}};
  bool succeed_store = true;
  for (auto iter = passes.begin(); iter != passes.end(); ++iter) {
    if (PassStorage::StorePass(iter->first, iter->second) != RET_OK) {
      MS_LOG(ERROR) << "external pass name conflicts with that of internal pass, the pass name is " << iter->first
                    << ", please edit external pass name.";
      succeed_store = false;
    }
  }
  return succeed_store;
}

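// Entry point: registers the built-in passes, then transforms `main_graph`.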
FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &main_graph, const converter::Flags *config) {
  if (!StoreBuiltinPass(config)) {
    MS_LOG(ERROR) << "store pass failed.";
    return nullptr;
  }
  auto new_graph = TransformFuncGraph(main_graph, config);
  if (new_graph == nullptr) {
    MS_LOG(ERROR) << "optimizer failed.";
    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(RET_NULL_PTR);
  }
  return new_graph;
}
}  // namespace mindspore::lite