/**
 * Copyright 2020-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define USE_DEPRECATED_API
#include "tools/converter/anf_transform.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <deque>
#include <map>
#include <tuple>
#include "nnacl/op_base.h"
#include "src/common/log_adapter.h"
#include "tools/converter/optimizer_manager.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/optimizer/common/pass_manager_extends.h"
#include "ir/primitive.h"
#include "tools/optimizer/fusion/add_activation_fusion.h"
#include "tools/optimizer/fusion/affine_activation_fusion.h"
#include "tools/optimizer/fusion/affine_fusion.h"
#include "tools/optimizer/fusion/conv_biasadd_fusion.h"
#include "tools/optimizer/fusion/conv_activation_fusion.h"
#include "tools/optimizer/fusion/adjust_matmul_pass.h"
#include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
#include "tools/optimizer/const_fold/constant_folding_fusion.h"
#include "tools/optimizer/fusion/hard_swish_fusion.h"
#include "tools/optimizer/fusion/norm_fusion.h"
#include "tools/optimizer/fusion/prelu_fusion.h"
#include "tools/optimizer/fusion/batchmatmul_fusion.h"
#include "tools/optimizer/fusion/batchnorm_to_scale_fusion.h"
#include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
#include "tools/optimizer/fusion/conv_conv_fusion.h"
#include "tools/optimizer/fusion/conv_pad_fusion.h"
#include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
#include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
#include "tools/optimizer/fusion/tensor_dot_fusion.h"
#include "tools/optimizer/fusion/multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/encoder_layer_fusion.h"
#include "tools/optimizer/fusion/decoder_layer_fusion.h"
#include "tools/optimizer/fusion/glu_fusion.h"
#include "tools/optimizer/graph/unused_add_node_remove_pass.h"
#include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
#include "tools/optimizer/fusion/matmul_add_fusion.h"
#include "tools/optimizer/fusion/matmul_mul_fusion.h"
#include "tools/optimizer/fusion/mul_add_fusion.h"
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
#include "tools/optimizer/fusion/squeeze_fusion.h"
#include "tools/optimizer/fusion/reshape_reshape_fusion.h"
#include "tools/optimizer/fusion/reshape_transpose_fusion.h"
#include "tools/optimizer/fusion/transpose_matmul_fusion.h"
#include "tools/optimizer/fusion/scale_activation_fusion.h"
#include "tools/optimizer/fusion/scale_scale_fusion.h"
#include "tools/optimizer/fusion/resize_fusion.h"
#include "tools/optimizer/fusion/fullconnected_fusion.h"
#include "tools/optimizer/fusion/fullconnected_add_fusion.h"
#include "tools/optimizer/fusion/add_concat_activation_fusion.h"
#include "tools/optimizer/fusion/matmul_activation_fusion.h"
#include "tools/optimizer/fusion/mul_activation_fusion.h"
#include "tools/optimizer/fusion/activation_fusion.h"
#include "tools/optimizer/fusion/reshape_reduce_fusion.h"
#include "tools/optimizer/fusion/add_layernorm_fusion.h"
#include "tools/optimizer/graph/add_tensor_array.h"
#include "tools/optimizer/graph/redundant_op_remove_pass.h"
#include "tools/optimizer/graph/clip_convert_activation_pass.h"
#include "tools/optimizer/graph/mul_constant_pass.h"
#include "tools/optimizer/graph/update_conv2d_param_pass.h"
#include "tools/optimizer/graph/infershape_pass.h"
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include "tools/optimizer/graph/control_flow_pass.h"
#include "tools/optimizer/graph/reduce_same_act_pass.h"
#include "tools/optimizer/graph/split_one_pass.h"
#include "tools/optimizer/graph/decrease_transpose_algo.h"
#include "tools/optimizer/graph/special_node_postprocess.h"
#include "tools/optimizer/graph/specify_graph_input_format.h"
#include "tools/optimizer/graph/dump_graph.h"
#include "tools/optimizer/graph/eliminate_redundant_cast_pass.h"
#include "tools/converter/quantizer/quantization_optimizer.h"
#include "tools/optimizer/parallel/split_strategy.h"
#include "tools/optimizer/parallel/spliter.h"
#include "tools/optimizer/fisson/iter_node_outputs.h"
#include "tools/optimizer/fisson/node_out_shapes.h"
#include "tools/optimizer/parallel/parallel_pass.h"
#include "include/registry/pass_registry.h"
#include "tools/optimizer/fisson/multi_conv_split_pass.h"
#include "tools/optimizer/fusion/transpose_fusion.h"
#include "tools/optimizer/format/to_nchw_format.h"
#include "tools/optimizer/graph/int64_cast_int32_pass.h"
#include "tools/optimizer/graph/input_data_type_trans_pass.h"
#include "tools/optimizer/fusion/cast_fusion.h"
#include "tools/optimizer/format/to_nhwc_format.h"
#include "tools/optimizer/fusion/expanddims_reshape_fusion.h"
#include "tools/optimizer/fusion/reduce_same_op_in_horizon.h"
#include "tools/optimizer/fusion/reshape_shape_fusion.h"
#include "tools/optimizer/fusion/transpose_gather_fusion.h"
#ifndef ENABLE_CLOUD_FUSION_INFERENCE
#include "tools/converter/adapter/acl/acl_pass.h"
#endif
#include "src/common/log_util.h"
#include "src/common/string_utils.h"
#include "src/common/config_infos.h"
#include "tools/graph_kernel/converter/graph_kernel_optimization.h"
#include "tools/optimizer/fusion/groupnorm_fusion.h"
#include "tools/optimizer/fusion/mul_reduce_fusion.h"
#include "tools/optimizer/fusion/reshape_like_operator_ablation.h"
#include "tools/optimizer/fusion/concat_concat_fusion.h"
#include "tools/optimizer/fusion/strided_slice_fusion.h"
#include "tools/optimizer/fusion/reduce_stack_fusion.h"
#include "tools/optimizer/fusion/remove_transitivity_op.h"
#include "tools/converter/import/cast_op_adjust.h"
#include "tools/converter/adapter/acl/plugin/acl_pass_plugin.h"
#include "tools/converter/quantizer/quant_helper/qat_transform.h"
#include "tools/converter/parser/conv2d_transpose_input_adjust.h"
#include "tools/converter/parser/parser_utils.h"
#include "tools/converter/parser/unify_format.h"
#include "include/backend/optimizer/graph_optimizer.h"
#include "tools/optimizer/fusion/squeeze_expanddims_fusion.h"
#include "mindspore/core/ops/op_name.h"
#include "tools/common/string_util.h"
#include "src/common/common.h"
#include "tools/optimizer/graph/miniaturization_pass.h"
#include "tools/optimizer/graph/scalar_op_pass.h"
#include "tools/optimizer/fusion/tile_matmul_fusion.h"
#include "tools/optimizer/fusion/flash_attention_fusion_for_custom.h"
#include "tools/optimizer/fusion/gegluv2_fusion.h"
#include "tools/optimizer/fusion/ffn_fusion.h"
#include "tools/optimizer/graph/make_list_pass.h"
#include "tools/optimizer/fusion/flash_attention_fusion.h"
#include "tools/optimizer/fusion/groupnormsilu_fusion.h"
#include "tools/optimizer/fusion/adjust_resize_dims_pass.h"

using std::string;
namespace mindspore::lite {
namespace {
constexpr auto kOriginalFmkType = "original_fmk_type";
constexpr auto kConverterInputShape = "converter_input_shape";

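// Serializes a map of graph-input shapes into the single string stored in the
// funcgraph attr kConverterInputShape: dimensions are joined with ',', name and
// dims with ':', and entries with ';'. For example, {{"x", {1, 3, 224, 224}}}
// becomes "x:1,3,224,224" (the tensor name "x" is only illustrative).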
std::string TransInputShapesToString(const std::map<std::string, std::vector<int64_t>> &shapes) {
  std::stringstream str_stream;
  size_t shape_index = 0;
  for (auto &item : shapes) {
    str_stream << item.first << ":";
    auto &shape = item.second;
    for (size_t d = 0; d < shape.size(); d++) {
      str_stream << shape[d];
      if (d + 1 != shape.size()) {
        str_stream << ",";
      }
    }
    if (shape_index + 1 != shapes.size()) {
      str_stream << ";";
    }
    shape_index++;
  }
  return str_stream.str();
}

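// Inverse of TransInputShapesToString: parses "name:d0,d1,...;name:..." back
// into a shape map. rfind(':') is used to split each entry, so tensor names may
// themselves contain ':'; any malformed entry discards the whole result.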
std::map<std::string, std::vector<int64_t>> TransStringToInputShapes(const std::string &shapes_str) {
  std::map<std::string, std::vector<int64_t>> shapes;
  auto shapes_pairs = lite::SplitStringToVector(shapes_str, ';');
  for (auto &kv_str : shapes_pairs) {
    auto pos = kv_str.rfind(':');
    if (pos == std::string::npos || pos + 1 == kv_str.size()) {
      MS_LOG(ERROR) << "Invalid input shapes string: " << shapes_str;
      return {};
    }
    auto name = kv_str.substr(0, pos);
    auto shape_str = kv_str.substr(pos + 1);
    auto shape_dims_str = lite::SplitStringToVector(shape_str, ',');
    std::vector<int64_t> shape;
    shape.reserve(shape_dims_str.size());
    for (auto &dim_str : shape_dims_str) {
      int dim = 0;
      if (!lite::ConvertIntNum(dim_str, &dim)) {
        MS_LOG(ERROR) << "Invalid input shapes string: " << shapes_str;
        return {};
      }
      shape.push_back(dim);
    }
    shapes[name] = shape;
  }
  return shapes;
}
}  // namespace

AnfTransform::AnfTransform() = default;

AnfTransform::~AnfTransform() = default;

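// Marks every CNode that directly feeds `cnode` with the "trainOp" attribute,
// so that operators belonging to the training part of the graph can be
// recognized by later passes.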
STATUS AnfTransform::MarkTrainInputOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  for (size_t i = 1; i < cnode->size(); i++) {
    auto input_node = cnode->input(i);
    if (!utils::isa<CNodePtr>(input_node)) {
      continue;
    }
    auto input_cnode = utils::cast<CNodePtr>(input_node);
    MS_CHECK_TRUE_RET(input_cnode != nullptr, RET_ERROR);
    auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
    if (prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    (void)prim->AddAttr("trainOp", MakeValue(true));
  }
  return RET_OK;
}

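// Marks any CNode in the graph that shares a Parameter (weight) input with
// `cnode` as "trainOp", presumably so that weight-sharing operators are treated
// consistently with the training op itself.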
STATUS AnfTransform::MarkTrainWeightSharingOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto graph_cnode = utils::cast<CNodePtr>(node);
    MS_CHECK_TRUE_RET(graph_cnode != nullptr, RET_ERROR);
    auto graph_prim = GetValueNode<PrimitivePtr>(graph_cnode->input(0));
    if (graph_prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    for (size_t i = 1; i < graph_cnode->size(); i++) {
      for (size_t j = 1; j < cnode->size(); j++) {
        if ((graph_cnode->input(i) == cnode->input(j)) && utils::isa<Parameter>(cnode->input(j))) {
          (void)graph_prim->AddAttr("trainOp", MakeValue(true));
        }
      }
    }
  }
  return RET_OK;
}

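// Walks the whole graph and, for every training operator found, propagates the
// "trainOp" attribute to the op itself, its direct inputs, and any op sharing
// its weights.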
STATUS AnfTransform::MarkTrainOp(const FuncGraphPtr &func_graph) {
  auto node_list = TopoSort(func_graph->get_return());
  for (auto &node : node_list) {
    if (!utils::isa<CNodePtr>(node)) {
      continue;
    }
    auto cnode = utils::cast<CNodePtr>(node);
    MS_CHECK_TRUE_RET(cnode != nullptr, RET_ERROR);
    auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    if (prim == nullptr) {
      MS_LOG(DEBUG) << "Primitive is nullptr.";
      continue;
    }
    if (opt::IsTrainOp(cnode)) {
      (void)prim->AddAttr("trainOp", MakeValue(true));
      auto status = MarkTrainInputOp(func_graph, cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "MarkTrainInputOp failed.";
        return RET_ERROR;
      }
      status = MarkTrainWeightSharingOp(func_graph, cnode);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "MarkTrainWeightSharingOp failed.";
        return RET_ERROR;
      }
    }
  }
  return RET_OK;
}
std::vector<opt::PassPtr> InitFusions(const std::shared_ptr<ConverterPara> &param) {
  // The training model only does the fusion of the inference part
  // remove quantdtype when awaretraining
  std::vector<opt::PassPtr> fusions{std::make_shared<opt::AddConcatActivationFusion>(),
                                    std::make_shared<opt::HardSwishFusion>(),
                                    std::make_shared<opt::PReluFusion>(),
                                    std::make_shared<opt::SqueezeFusion>(),
                                    std::make_shared<opt::TransposeFusion>(),
                                    std::make_shared<opt::CastFusionPass>(),
                                    std::make_shared<opt::ReshapeReshapeFusion>(),
                                    std::make_shared<opt::ReshapeTransposeFusion>(),
                                    std::make_shared<opt::ConvBiasaddFusion>(),
                                    std::make_shared<opt::ConvBatchNormFusion>(param->fmk_type),
                                    std::make_shared<opt::ConvScaleFusion>(param->fmk_type),
                                    std::make_shared<opt::GroupNormFusion>(),
                                    std::make_shared<opt::TfNormFusion>(),
                                    std::make_shared<opt::OnnxLayerNormFusion>(),
                                    std::make_shared<opt::OnnxLayerNormFusion2>(),
                                    std::make_shared<opt::BatchMatMulFusion>(),
                                    std::make_shared<opt::BatchNormToScaleFusion>(),
                                    std::make_shared<opt::SigmoidMulFusion>(),
                                    std::make_shared<opt::ActivationFusion>(),
                                    std::make_shared<opt::ConvActivationFusion>(param),
                                    std::make_shared<opt::ConvTupleGetItemFusion>(),
                                    std::make_shared<opt::ConvTupleActivationFusion>(),
                                    std::make_shared<opt::TfliteLstmCellFusion>(),
                                    std::make_shared<opt::TfLstmCellFusion>(),
                                    std::make_shared<opt::TfBidirectionGruFusion>(),
                                    std::make_shared<opt::TfGeLUFusion>(),
                                    std::make_shared<opt::OnnxGeLUFusion>(),
                                    std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>(),
                                    std::make_shared<opt::GLUFusion>(),
                                    std::make_shared<opt::ResizeFusion1>(),
                                    std::make_shared<opt::ResizeFusion2>(),
                                    std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model),
                                    std::make_shared<opt::AffineFusion>(),
                                    std::make_shared<opt::AffineActivationFusion>(),
                                    std::make_shared<opt::ConvConvFusion>(),
                                    std::make_shared<opt::ConvPadFusion>(),
                                    std::make_shared<opt::MatMulAddFusion>(),
                                    std::make_shared<opt::MatMulMulFusion>(),
                                    std::make_shared<opt::TransposeMatMulFusion>(),
                                    std::make_shared<opt::MulAddFusion>(),
                                    std::make_shared<opt::ScaleActivationFusion>(),
                                    std::make_shared<opt::ScaleScaleFusion>(),
                                    std::make_shared<opt::FullConnectedFusion>(),
                                    std::make_shared<opt::FullconnectedAddFusion>(),
                                    std::make_shared<opt::TensorDotFusion>(),
                                    std::make_shared<opt::MatMulActivationFusion>(param),
                                    std::make_shared<opt::MulActivationFusion>(),
                                    std::make_shared<opt::AddActivationFusion>(),
                                    std::make_shared<opt::ExpandDimsReshapeFusion>(),
                                    std::make_shared<opt::SqueezeExpandDimsFusion>(),
                                    std::make_shared<opt::TileMatMulFusion>()};
  if (param->optimize_transformer) {
    fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>());
    fusions.push_back(std::make_shared<opt::EncoderLayerFusion>(true));
    fusions.push_back(std::make_shared<opt::EncoderLayerFusion>(false));
    fusions.push_back(std::make_shared<opt::DecoderLayerFusion>());
  }
  return fusions;
}

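// Applies the fusion passes from InitFusions through a pass manager, skipping
// any pass whose name appears in param->fusion_blacklists, then runs a second
// batch of fusion passes whose boolean return values are checked one by one.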
int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  auto status = MarkTrainOp(old_graph);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "MarkTrainOp failed.";
    return RET_ERROR;
  }
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  CHECK_NULL_RETURN(optimizer);
  auto fusion_pm = std::make_shared<opt::LitePassManager>("anf fusion pass manager", false);
  CHECK_NULL_RETURN(fusion_pm);

  auto fusions = InitFusions(param);
  for (size_t index = 0; index < fusions.size(); index++) {
    auto pass_ptr = fusions.at(index);
    MS_CHECK_TRUE_RET(pass_ptr != nullptr, RET_ERROR);
    auto pass_name = pass_ptr->name();
    if (param->fusion_blacklists.find(pass_name) != param->fusion_blacklists.end()) {
      MS_LOG(INFO) << "Disable fusion: " << pass_name;
      continue;
    }
    fusion_pm->AddPass(pass_ptr);
  }
  optimizer->AddPassManager(fusion_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run op fusion failed.";
    return RET_ERROR;
  }

  // The return values of the following passes need to be checked individually.
  fusions = {std::make_shared<opt::ReduceSameOpInHorizon>(param), std::make_shared<opt::ReshapeReduceFusion>(),
             std::make_shared<opt::AblateReshapeLikeOp>(),        std::make_shared<opt::MulReduceFusion>(),
             std::make_shared<opt::ConcatConcatFusion>(),         std::make_shared<opt::ReduceStackFusion>(),
             std::make_shared<opt::RemoveTransitivityOp>(),       std::make_shared<opt::StridedSliceFusion>(),
             std::make_shared<opt::RemoveTransitivityOp>(),       std::make_shared<opt::ReshapeShapeFusion>(),
             std::make_shared<opt::TransposeGatherFusion>()};
  for (auto &pass : fusions) {
    MS_CHECK_TRUE_MSG(pass != nullptr, RET_ERROR, "pass is a nullptr.");
    if (param->fusion_blacklists.find(pass->name()) != param->fusion_blacklists.end()) {
      MS_LOG(INFO) << "Disable fusion: " << pass->name();
      continue;
    }
    if (!pass->Run(old_graph)) {
      MS_LOG(ERROR) << pass->name() << " running failed.";
      return RET_ERROR;
    }
  }
  return RET_OK;
}

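// Optionally splits multi-conv subgraphs across devices according to
// param->parallel_split_config; a no-op for training models or when the split
// type is SplitNo.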
int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  MS_LOG(DEBUG) << "Run ParallelPass start";
  if (param->train_model || param->parallel_split_config.parallel_split_type_ == SplitNo) {
    return RET_OK;
  }
  if (param->parallel_split_config.parallel_split_type_ == SplitByUserRatio) {
    auto optimizer = std::make_shared<opt::GraphOptimizer>();
    CHECK_NULL_RETURN(optimizer);
    auto graph_inputs = old_graph->get_inputs();
    opt::SplitMode split_mode = opt::NoSplit;
    for (const auto &graph_input : graph_inputs) {
      if (utils::isa<Parameter>(graph_input)) {
        auto input_parameter = dyn_cast<Parameter>(graph_input);
        MSLITE_CHECK_PTR(input_parameter->Shape());
        auto shape_ptr = input_parameter->Shape()->cast<abstract::ShapePtr>();
        MSLITE_CHECK_PTR(shape_ptr);
        auto batch = shape_ptr->shape().front();
        if (batch > opt::kDefaultBatch) {
          split_mode = opt::SplitN;
        } else {
          split_mode = opt::SplitH;
        }
        break;
      }
    }
    // 1. deal with split strategy
    std::unordered_map<std::string, opt::SplitStrategy> split_strategys = opt::ParserSplitStrategy(
      param->parallel_split_config.parallel_compute_rates_, param->parallel_split_config.parallel_devices_, split_mode);
    if (split_strategys.empty()) {
      MS_LOG(WARNING) << "No valid split_strategy. Run convert without split";
      return RET_OK;
    }
    opt::Spliter::GetInstance()->RecordGraphInfo(old_graph);
    auto parallel_pm = std::make_shared<opt::LitePassManager>("anf parallel pass manager", true);
    CHECK_NULL_RETURN(parallel_pm);
    // 2. preceding parallel pass
    parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
    parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
    std::set<int, opt::IntCompare> match_multi_numbers = opt::Spliter::GetInstance()->graph_match_multi_numbers();
    int max_match_number = *match_multi_numbers.begin();
    // we do not deal with single conv node
    for (int match_number = max_match_number; match_number > opt::kDefaultBatch; --match_number) {
      // 3. multi_conv parallel pass
      parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, param->fmk_type, match_number));
      parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
      parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
    }
    optimizer->AddPassManager(parallel_pm);
    if (optimizer->Optimize(old_graph) == nullptr) {
      MS_LOG(ERROR) << "run parallel pass failed.";
      return RET_ERROR;
    }
  }
  MS_LOG(DEBUG) << "Run ParallelPass end";
  return RET_OK;
}

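// Graph-level passes: control-flow handling for TFLite/TF/ONNX models and the
// slice-prepose optimization.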
int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  CHECK_NULL_RETURN(optimizer);
  auto graph_pm = std::make_shared<opt::LitePassManager>("anf graph pass manager", true);
  CHECK_NULL_RETURN(graph_pm);
  if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
      param->fmk_type == converter::kFmkTypeOnnx) {
    graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
  }
  auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
  CHECK_NULL_RETURN(slice_prepose_pass);
  slice_prepose_pass->SetFmkType(param->fmk_type);
  graph_pm->AddPass(slice_prepose_pass);
  optimizer->AddPassManager(graph_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run graph pass failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

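// Converter-level adjustments: runs the ACL pass on Ascend targets, fixes up
// Conv2dTranspose inputs, then removes redundant ops, infers shapes, adjusts
// cast ops, and updates Conv2D parameters.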
int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  if (param->device.find("Ascend") != std::string::npos) {
    auto acl_pass_ptr = opt::AclPassPlugin::CreateAclPass(param);
    if (acl_pass_ptr == nullptr) {
      MS_LOG(ERROR) << "Failed to create acl pass";
      return RET_ERROR;
    }
    if (!acl_pass_ptr->Run(old_graph)) {
      MS_LOG(ERROR) << "Acl pass failed.";
      return RET_ERROR;
    }
  }
  // adjust for conv2d_transpose
  if (!(param->no_fusion && param->save_type == kMindIR)) {
    std::set<FuncGraphPtr> all_func_graphs = {};
    GetAllFuncGraph(old_graph, &all_func_graphs);
    auto conv2d_transpose_adjust = std::make_shared<Conv2DTransposeInputAdjust>();
    MS_CHECK_TRUE_MSG(conv2d_transpose_adjust != nullptr, RET_NULL_PTR, "conv2d_transpose_adjust is nullptr.");
    for (auto sub_graph : all_func_graphs) {
      if (!conv2d_transpose_adjust->Run(sub_graph)) {  // run on each collected subgraph
        MS_LOG(ERROR) << "adjust conv2d_transpose failed";
        return RET_ERROR;
      }
    }
  }
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  CHECK_NULL_RETURN(optimizer);
  auto convert_pm = std::make_shared<opt::LitePassManager>("anf graph convert pass manager", true);
  CHECK_NULL_RETURN(convert_pm);
  convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
  convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
  convert_pm->AddPass(std::make_shared<opt::CastOpAdjust>());
  convert_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
  optimizer->AddPassManager(convert_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run graph convert pass failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

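// Shape inference plus constant folding (inference models only), followed by
// Conv2D-parameter update and clip-to-activation conversion.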
int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto const_fold_pm = std::make_shared<opt::LitePassManager>("const fold fusion pass manager", false);
  CHECK_NULL_RETURN(optimizer);
  CHECK_NULL_RETURN(const_fold_pm);
  if (param->train_model) {
    const_fold_pm->AddPass(std::make_shared<opt::MiniaturizationPass>());
  }
  const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
  if (!param->train_model) {
    const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model));
  }
  const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
  const_fold_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>(), param->fusion_blacklists);
  optimizer->AddPassManager(const_fold_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run const fold failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

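// Downcasts int64 tensors to int32 where possible, then re-fuses the casts
// this introduces.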
int AnfTransform::RunInt64CastInt32Pass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto int64_cast_int32_pm = std::make_shared<opt::LitePassManager>("int64 cast to int32 pass manager", false);
  CHECK_NULL_RETURN(optimizer);
  CHECK_NULL_RETURN(int64_cast_int32_pm);
  int64_cast_int32_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
  int64_cast_int32_pm->AddPass(std::make_shared<opt::Int64CastInt32Pass>());
  int64_cast_int32_pm->AddPass(std::make_shared<opt::CastFusionPass>());

  optimizer->AddPassManager(int64_cast_int32_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run int64 cast to int32 pass failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

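// Free function: reduces the number of Transpose nodes in the graph, then
// re-runs the reshape/transpose fusions that the reduction may enable, honoring
// the fusion blacklist.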
int RunDecreaseTransposePass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  MS_ASSERT(old_graph != nullptr && param != nullptr);
  auto pass = std::make_shared<opt::DecreaseTransposeAlgo>(param->fmk_type, param->train_model, false);
  MS_CHECK_TRUE_RET(pass != nullptr, RET_ERROR);
  if (!pass->Run(old_graph)) {
    MS_LOG(ERROR) << "Run DecreaseTransposeAlgo pass failed";
    return RET_ERROR;
  }

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto decrease_trans_pm = std::make_shared<opt::LitePassManager>("decrease transpose fusion pass manager", false);
  CHECK_NULL_RETURN(optimizer);
  CHECK_NULL_RETURN(decrease_trans_pm);
  std::vector<opt::PassPtr> fusions = {std::make_shared<opt::ReshapeTransposeFusion>(),
                                       std::make_shared<opt::TransposeFusion>()};
  (void)std::for_each(fusions.begin(), fusions.end(), [&decrease_trans_pm, &param](opt::PassPtr fusion) {
    if (fusion != nullptr && param->fusion_blacklists.find(fusion->name()) == param->fusion_blacklists.end()) {
      decrease_trans_pm->AddPass(fusion);
    }
  });
  optimizer->AddPassManager(decrease_trans_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run decrease transpose failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

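// External extension plugins cannot be combined with quantization; RunPass
// rejects a configuration that sets both.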
bool AnfTransform::CheckExternalExtension(const std::shared_ptr<ConverterPara> &param) {
  return (!param->plugins_path.empty() && param->commonQuantParam.quant_type != quant::QUANT_NONE);
}

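// Delegates quantization of the graph to quant::QuantizationOptimizer.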
int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  quant::QuantizationOptimizer quantization_optimizer(param);
  auto ret = quantization_optimizer.Run(old_graph);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Post training quantization failed.";
    return ret;
  }
  return RET_OK;
}

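// For MindIR export: converts the graph to NCHW (unless it was fused for an
// Ascend target) and records the original framework type as a graph attribute.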
int AnfTransform::DoFormatForMindIR(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  if (param->save_type != kMindIR) {
    return RET_OK;
  }
  if (param->no_fusion || param->device.find("Ascend") == std::string::npos) {
    MS_LOG(INFO) << "export MindIR, run pass ToNCHWFormat";
    if (!RunOptimizerPass(old_graph, {"ToNCHWFormat", "DecreaseTransposeAlgo"})) {
      MS_LOG(ERROR) << "Run ToNCHWFormat pass failed";
      return RET_ERROR;
    }
  }
  old_graph->set_attr(kOriginalFmkType, MakeValue(static_cast<int32_t>(param->fmk_type)));

  return RET_OK;
}

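// Converts a loaded graph to NHWC unless its kFormat attribute already says
// NHWC.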
int AnfTransform::RunFormatTrans(const FuncGraphPtr &old_graph) {
  auto value = old_graph->get_attr(ops::kFormat);
  if (value != nullptr && GetValue<int64_t>(value) == mindspore::NHWC) {
    return RET_OK;
  }
  if (!RunOptimizerPass(old_graph, {"ToNHWCFormat", "DecreaseTransposeAlgo"})) {
    MS_LOG(ERROR) << "Run ToNHWCFormat pass failed";
    return RET_ERROR;
  }
  return RET_OK;
}

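// Free function: after shape inference and transpose reduction, runs the
// redundancy-elimination pass manager (redundant ops, redundant casts, same
// activations, single-output splits, constant multiplications).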
bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
    MS_LOG(WARNING) << "Run infershape opt pass failed.";
  } else if (!RunOptimizerPass(old_graph, {"DecreaseTransposeAlgo"})) {
    MS_LOG(ERROR) << "Run transpose opt pass failed.";
    return false;
  }

  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  MS_CHECK_TRUE_RET(optimizer != nullptr, false);
  auto eliminate_pm = std::make_shared<opt::LitePassManager>("anf graph eliminate redundant pass manager", true);
  MS_CHECK_TRUE_RET(eliminate_pm != nullptr, false);
  eliminate_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
  eliminate_pm->AddPass(std::make_shared<opt::EliminateRedundantCastPass>(param->fmk_type, param->train_model));
  eliminate_pm->AddPass(std::make_shared<opt::ReduceSameActPass>());
  eliminate_pm->AddPass(std::make_shared<opt::SplitOnePass>());
  eliminate_pm->AddPass(std::make_shared<opt::MulConstantPass>());
  optimizer->AddPassManager(eliminate_pm);
  if (optimizer->Optimize(old_graph) == nullptr) {
    MS_LOG(ERROR) << "run eliminate redundant pass failed.";
    return false;
  }
  return true;
}

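// Online (no-fusion MindIR) path: const-folds what it can, normalizes the
// format for MindIR export, and records the converter input shapes as a graph
// attribute so Transform() can restore them on reload.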
STATUS AnfTransform::ProcOnlineTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  if (!RunOptimizerPass(old_graph, {"RemoveRedundantOpPass", "InferShapePass", "ConstFoldPass"})) {
    MS_LOG(WARNING) << "Run infershape opt pass failed.";
  }
  auto status = DoFormatForMindIR(old_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Do format for mindir failed.";
    return lite::RET_ERROR;
  }
  if (!param->input_shape.empty()) {
    auto graph_inputs = old_graph->get_inputs();
    std::map<std::string, std::vector<int64_t>> input_shape;
    for (auto &input : graph_inputs) {
      auto abstract = input->abstract();
      if (abstract) {
        input_shape[abstract->name()] = opt::GetAnfNodeOutputShape(input, 0);
      }
    }
    old_graph->set_attr(kConverterInputShape, MakeValue(TransInputShapesToString(input_shape)));
  }
  return lite::RET_OK;
}

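// Main offline pipeline: convert passes, external passes (BEGIN), int64->int32
// cast, constant folding, redundancy elimination, fusion (unless disabled),
// external passes (END), transpose optimization, QAT transform, graph passes,
// and finally the parallel pass.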
int AnfTransform::RunPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  auto status = RunConvertPass(old_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run convert pass failed.";
    return RET_ERROR;
  }
  if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
    MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
    return RET_ERROR;
  }

  status = RunInt64CastInt32Pass(old_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "RunInt64CastInt32Pass failed.";
    return RET_ERROR;
  }
  status = RunConstFoldPass(old_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run const fold pass failed.";
    return RET_ERROR;
  }

  if (!RunEliminateRedundantPass(old_graph, param)) {
    MS_LOG(ERROR) << "Run elimination of redundant pass failed.";
    return RET_ERROR;
  }

  if (!param->no_fusion) {
    status = RunFusionPass(old_graph, param);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Run fusion pass failed.";
      return RET_ERROR;
    }
  }

  if (!RunExternalPass(old_graph, registry::POSITION_END)) {
    MS_LOG(ERROR) << "Run external pass failed, place is END";
    return RET_ERROR;
  }

  if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
    MS_LOG(WARNING) << "Run infershape opt pass failed.";
    status = RunOptimizerPass(old_graph, {"SpecialNodePostProcess"}) ? RET_OK : RET_ERROR;
  } else {
    status =
      RunOptimizerPass(old_graph, {"SpecialNodePostProcess"}) ? RunDecreaseTransposePass(old_graph, param) : RET_ERROR;
  }
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run transpose opt pass failed.";
    return RET_ERROR;
  }

  if (CheckExternalExtension(param)) {
    MS_LOG(ERROR) << "Unsupported external extension with quantization.";
    return RET_ERROR;
  }
  // QATTransform infers all subgraphs and therefore must run before ControlFlowPass: after ControlFlowPass
  // the main graph contains ops that cannot be handled, so InferShapePass can no longer be executed.
  auto qat_transform = quant::QATTransform(old_graph, param);
  status = qat_transform.Transform();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Do QATTransform failed.";
    return RET_ERROR;
  }

  status = RunGraphPass(old_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run graph pass failed.";
    return RET_ERROR;
  }

  status = RunParallelPass(old_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Run parallel pass failed.";
    return RET_ERROR;
  }
  return RET_OK;
}

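// Top-level transform driver: dispatches between the online (no-fusion MindIR)
// path and the full offline pipeline, then quantizes and, where enabled,
// applies graph-kernel optimization and MindIR format handling.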
STATUS AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> &param) {
  MS_ASSERT(old_graph != nullptr);
  MS_ASSERT(param != nullptr);
  if (param->no_fusion && param->save_type == kMindIR) {  // converter, online
    if (ProcOnlineTransform(old_graph, param) != lite::RET_OK) {
      MS_LOG(ERROR) << "Proc online transform failed.";
      return RET_ERROR;
    }
    auto status = DoQuantize(old_graph, param);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "Do Quantize failed.";
      return RET_ERROR;
    }
    return RET_OK;
  }
  auto value = old_graph->get_attr(kIsOptimized);
  if (param->is_runtime_converter) {  // load online
    if (value != nullptr) {           // other models converted to MindIR
      auto status = RunFormatTrans(old_graph);
      if (status != RET_OK) {
        MS_LOG(ERROR) << "Run format trans failed";
        return status;
      }
    }
  }

  if (RunPass(old_graph, param) != RET_OK) {
    MS_LOG(ERROR) << "Run pass failed.";
    return RET_ERROR;
  }

  auto status = DoQuantize(old_graph, param);
  if (status != RET_OK) {
    MS_LOG(ERROR) << "Do Quantize failed.";
    return RET_ERROR;
  }

#ifdef MSLITE_ENABLE_GRAPH_KERNEL
  if (param->device.find("Ascend") == std::string::npos) {
    if (GraphKernelOptimize(old_graph, param) != RET_OK) {
      MS_LOG(ERROR) << "Do graphkernel optimization failed.";
      return RET_ERROR;
    }
  }
#endif

  status = DoFormatForMindIR(old_graph, param);
  if (status != RET_OK) {
    return RET_ERROR;
  }
  return RET_OK;
}

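// Registers the built-in optimizer passes in PassStorage under stable names;
// the boolean in each tuple marks whether external extensions may invoke the
// pass. DumpGraph is also exposed through the public pass registry.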
bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> &param) {
  if (param == nullptr) {
    MS_LOG(ERROR) << "param is nullptr";
    return false;
  }
  auto fmk = param->fmk_type;
  auto is_train = param->train_model;

  // Each tuple holds the pass name, the pass, and a boolean indicating whether external extensions may invoke it.
  std::vector<std::tuple<std::string, opt::PassPtr, bool>> pass_infos = {
    {"DumpGraph", std::make_shared<opt::DumpGraph>(param), true},
    {"RemoveRedundantOpPass", std::make_shared<opt::RemoveRedundantOpPass>(param->train_model), false},
    {"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train, param->save_type), true},
    {"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train, param->save_type), true},
    {"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train), true},
    {"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train), false},
    {"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>(), false},
    {"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>(), false},
    {"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train), true},
    {"RemoveUnusedAddNodePass", std::make_shared<opt::RemoveUnusedAddNodePass>(), false},
    {"AdjustResizeDimsPass", std::make_shared<opt::AdjustResizeDimsPass>(), false},
    {"ScalarOpPass", std::make_shared<opt::ScalarOpPass>(), true},
    {"FlashAttentionFusionForCustom",
     std::make_shared<opt::FlashAttentionFusionForCustom>(param->aclModelOptionCfgParam.plugin_custom_ops,
                                                          param->aclModelOptionCfgParam.enable_custom_fusion_pattern,
                                                          param->aclModelOptionCfgParam.disable_custom_fusion_pattern),
     false},
    {"MakeListPass", std::make_shared<opt::MakeListPass>(), true},
    {"FlashAttentionFusion", std::make_shared<opt::FlashAttentionFusion>(param->aclModelOptionCfgParam.op_attrs_map),
     false},
    {"GroupNormSiluFusion", std::make_shared<opt::GroupNormSiluFusion>(), false},
    {"GeGluV2Fusion", std::make_shared<opt::GeGluV2Fusion>(), false},
    {"LayerNormV3Fusion", std::make_shared<opt::LayerNormV3Fusion>(), false},
    {"FFNFusion", std::make_shared<opt::FFNFusion>(), false},
    {"FuseAddAndLayernorm", std::make_shared<opt::FuseAddAndLayernorm>(), false},
    {"AdjustMatmulPass", std::make_shared<opt::AdjustMatmulPass>(), false}};
  for (const auto &pass_info : pass_infos) {
    MS_CHECK_TRUE_RET(std::get<1>(pass_info) != nullptr, false);
    PassStorage::StorePass(std::get<0>(pass_info), std::get<1>(pass_info), std::get<opt::kInputIndexTwo>(pass_info));
  }
  auto dump_graph_outer = std::make_shared<opt::DumpGraph>(param);
  MS_CHECK_TRUE_MSG(dump_graph_outer != nullptr, false, "dumpGraph object is a nullptr.");
  registry::PassRegistry("DumpGraph", dump_graph_outer);
  return true;
}

void AnfTransform::ClearBuiltinPass() { PassStorage::ClearPass(); }

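// Entry point: restores the framework type and input shapes persisted in the
// graph attrs, registers the built-in passes, runs the full transform, and
// clears the pass storage afterwards.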
STATUS AnfTransform::Transform(const FuncGraphPtr &main_graph, const std::shared_ptr<ConverterPara> &param) {
  MS_CHECK_TRUE_MSG(main_graph != nullptr, RET_NULL_PTR, "Input func_graph is nullptr");
  MS_CHECK_TRUE_MSG(param != nullptr, RET_NULL_PTR, "Input converter param is nullptr");
  manager_ = Manage(main_graph, true);

  if (main_graph->has_attr(kOriginalFmkType)) {
    auto val_ptr = main_graph->get_attr(kOriginalFmkType);
    MS_CHECK_TRUE_MSG(val_ptr != nullptr, RET_NULL_PTR, "Val ptr is nullptr.");
    param->fmk_type = static_cast<converter::FmkType>(GetValue<int32_t>(val_ptr));
  }
  if (main_graph->has_attr(kConverterInputShape)) {
    auto val_ptr = main_graph->get_attr(kConverterInputShape);
    MS_CHECK_TRUE_MSG(val_ptr != nullptr, RET_NULL_PTR, "Val ptr is nullptr.");
    param->input_shape = TransStringToInputShapes(GetValue<std::string>(val_ptr));
    for (auto &kv : param->input_shape) {
      lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(kv.first, kv.second);
    }
  }
  if (!StoreBuiltinPass(param)) {
    MS_LOG(ERROR) << "store pass failed.";
    return RET_ERROR;
  }

  auto status = TransformFuncGraph(main_graph, param);
  ClearBuiltinPass();
  if (status != RET_OK) {
    MS_LOG(ERROR) << "optimizer failed.";
    return RET_NULL_PTR;
  }

  return RET_OK;
}
}  // namespace mindspore::lite