1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/converter/anf_transform.h"
#include <algorithm>
#include <deque>
#include <map>
#include <memory>
#include <set>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
25 #include "nnacl/op_base.h"
26 #include "src/common/log_adapter.h"
27 #include "tools/converter/optimizer_manager.h"
28 #include "tools/optimizer/common/gllo_utils.h"
29 #include "tools/optimizer/common/pass_manager_extends.h"
30 #include "ir/primitive.h"
31 #include "tools/optimizer/fusion/add_activation_fusion.h"
32 #include "tools/optimizer/fusion/affine_activation_fusion.h"
33 #include "tools/optimizer/fusion/affine_fusion.h"
34 #include "tools/optimizer/fusion/conv_biasadd_fusion.h"
35 #include "tools/optimizer/fusion/conv_activation_fusion.h"
36 #include "tools/optimizer/fusion/adjust_matmul_pass.h"
37 #include "tools/optimizer/fusion/conv_tuple_activation_fusion.h"
38 #include "tools/optimizer/fusion/conv_scale_fusion.h"
39 #include "tools/optimizer/fusion/conv_bn_fusion.h"
40 #include "tools/optimizer/fusion/conv_tuplegetitem_fusion.h"
41 #include "tools/optimizer/const_fold/constant_folding_fusion.h"
42 #include "tools/optimizer/fusion/hard_swish_fusion.h"
43 #include "tools/optimizer/fusion/norm_fusion.h"
44 #include "tools/optimizer/fusion/prelu_fusion.h"
45 #include "tools/optimizer/fusion/batchmatmul_fusion.h"
46 #include "tools/optimizer/fusion/batchnorm_to_scale_fusion.h"
47 #include "tools/optimizer/fusion/sigmoid_mul_fusion.h"
48 #include "tools/optimizer/fusion/conv_conv_fusion.h"
49 #include "tools/optimizer/fusion/conv_pad_fusion.h"
50 #include "tools/optimizer/fusion/tflite_lstm_cell_fusion.h"
51 #include "tools/optimizer/fusion/tf_lstm_cell_fusion.h"
52 #include "tools/optimizer/fusion/tf_bidirection_gru_fusion.h"
53 #include "tools/optimizer/fusion/tensor_dot_fusion.h"
54 #include "tools/optimizer/fusion/multi_head_attention_fusion.h"
55 #include "tools/optimizer/fusion/encoder_layer_fusion.h"
56 #include "tools/optimizer/fusion/decoder_layer_fusion.h"
57 #include "tools/optimizer/fusion/glu_fusion.h"
58 #include "tools/optimizer/graph/unused_add_node_remove_pass.h"
59 #include "tools/optimizer/fusion/tflite_rel_pos_multi_head_attention_fusion.h"
60 #include "tools/optimizer/fusion/matmul_add_fusion.h"
61 #include "tools/optimizer/fusion/matmul_mul_fusion.h"
62 #include "tools/optimizer/fusion/mul_add_fusion.h"
63 #include "tools/optimizer/fusion/tf_gelu_fusion.h"
64 #include "tools/optimizer/fusion/onnx_gelu_fusion.h"
65 #include "tools/optimizer/fusion/squeeze_fusion.h"
66 #include "tools/optimizer/fusion/reshape_reshape_fusion.h"
67 #include "tools/optimizer/fusion/reshape_transpose_fusion.h"
68 #include "tools/optimizer/fusion/transpose_matmul_fusion.h"
69 #include "tools/optimizer/fusion/scale_activation_fusion.h"
70 #include "tools/optimizer/fusion/scale_scale_fusion.h"
71 #include "tools/optimizer/fusion/resize_fusion.h"
72 #include "tools/optimizer/fusion/fullconnected_fusion.h"
73 #include "tools/optimizer/fusion/fullconnected_add_fusion.h"
74 #include "tools/optimizer/fusion/add_concat_activation_fusion.h"
75 #include "tools/optimizer/fusion/matmul_activation_fusion.h"
76 #include "tools/optimizer/fusion/mul_activation_fusion.h"
77 #include "tools/optimizer/fusion/activation_fusion.h"
78 #include "tools/optimizer/fusion/reshape_reduce_fusion.h"
79 #include "tools/optimizer/fusion/add_layernorm_fusion.h"
80 #include "tools/optimizer/graph/add_tensor_array.h"
81 #include "tools/optimizer/graph/redundant_op_remove_pass.h"
82 #include "tools/optimizer/graph/clip_convert_activation_pass.h"
83 #include "tools/optimizer/graph/mul_constant_pass.h"
84 #include "tools/optimizer/graph/update_conv2d_param_pass.h"
85 #include "tools/optimizer/graph/infershape_pass.h"
86 #include "tools/optimizer/graph/slice_prepose_pass.h"
87 #include "tools/optimizer/graph/control_flow_pass.h"
88 #include "tools/optimizer/graph/reduce_same_act_pass.h"
89 #include "tools/optimizer/graph/split_one_pass.h"
90 #include "tools/optimizer/graph/decrease_transpose_algo.h"
91 #include "tools/optimizer/graph/special_node_postprocess.h"
92 #include "tools/optimizer/graph/specify_graph_input_format.h"
93 #include "tools/optimizer/graph/dump_graph.h"
94 #include "tools/optimizer/graph/eliminate_redundant_cast_pass.h"
95 #include "tools/converter/quantizer/quantization_optimizer.h"
96 #include "tools/optimizer/parallel/split_strategy.h"
97 #include "tools/optimizer/parallel/spliter.h"
98 #include "tools/optimizer/fisson/iter_node_outputs.h"
99 #include "tools/optimizer/fisson/node_out_shapes.h"
100 #include "tools/optimizer/parallel/parallel_pass.h"
101 #include "include/registry/pass_registry.h"
102 #include "tools/optimizer/fisson/multi_conv_split_pass.h"
103 #include "tools/optimizer/fusion/transpose_fusion.h"
104 #include "tools/optimizer/format/to_nchw_format.h"
105 #include "tools/optimizer/graph/int64_cast_int32_pass.h"
106 #include "tools/optimizer/graph/input_data_type_trans_pass.h"
107 #include "tools/optimizer/fusion/cast_fusion.h"
108 #include "tools/optimizer/format/to_nhwc_format.h"
109 #include "tools/optimizer/fusion/expanddims_reshape_fusion.h"
110 #include "tools/optimizer/fusion/reduce_same_op_in_horizon.h"
111 #include "tools/optimizer/fusion/reshape_shape_fusion.h"
112 #include "tools/optimizer/fusion/transpose_gather_fusion.h"
113 #ifndef ENABLE_CLOUD_FUSION_INFERENCE
114 #include "tools/converter/adapter/acl/acl_pass.h"
115 #endif
116 #include "src/common/log_util.h"
117 #include "src/common/string_utils.h"
118 #include "src/common/config_infos.h"
119 #include "tools/graph_kernel/converter/graph_kernel_optimization.h"
120 #include "tools/optimizer/fusion/groupnorm_fusion.h"
121 #include "tools/optimizer/fusion/mul_reduce_fusion.h"
122 #include "tools/optimizer/fusion/reshape_like_operator_ablation.h"
123 #include "tools/optimizer/fusion/concat_concat_fusion.h"
124 #include "tools/optimizer/fusion/strided_slice_fusion.h"
125 #include "tools/optimizer/fusion/reduce_stack_fusion.h"
126 #include "tools/optimizer/fusion/remove_transitivity_op.h"
127 #include "tools/converter/import/cast_op_adjust.h"
128 #include "tools/converter/adapter/acl/plugin/acl_pass_plugin.h"
129 #include "tools/converter/quantizer/quant_helper/qat_transform.h"
130 #include "tools/converter/parser/conv2d_transpose_input_adjust.h"
131 #include "tools/converter/parser/parser_utils.h"
132 #include "tools/converter/parser/unify_format.h"
133 #include "include/backend/optimizer/graph_optimizer.h"
134 #include "tools/optimizer/fusion/squeeze_expanddims_fusion.h"
135 #include "mindspore/core/ops/op_name.h"
136 #include "tools/common/string_util.h"
137 #include "src/common/common.h"
138 #include "tools/optimizer/graph/miniaturization_pass.h"
139 #include "tools/optimizer/graph/scalar_op_pass.h"
140 #include "tools/optimizer/fusion/tile_matmul_fusion.h"
141 #include "tools/optimizer/fusion/flash_attention_fusion_for_custom.h"
142 #include "tools/optimizer/fusion/gegluv2_fusion.h"
143 #include "tools/optimizer/fusion/ffn_fusion.h"
144 #include "tools/optimizer/graph/make_list_pass.h"
145 #include "tools/optimizer/fusion/flash_attention_fusion.h"
146 #include "tools/optimizer/fusion/groupnormsilu_fusion.h"
147 #include "tools/optimizer/fusion/adjust_resize_dims_pass.h"
148
149 using std::string;
150 namespace mindspore::lite {
151 namespace {
152 constexpr auto kOriginalFmkType = "original_fmk_type";
153 constexpr auto kConverterInputShape = "converter_input_shape";
154
TransInputShapesToString(const std::map<std::string,std::vector<int64_t>> & shapes)155 std::string TransInputShapesToString(const std::map<std::string, std::vector<int64_t>> &shapes) {
156 std::stringstream str_stream;
157 size_t shape_index = 0;
158 for (auto &item : shapes) {
159 str_stream << item.first << ":";
160 auto &shape = item.second;
161 for (size_t d = 0; d < shape.size(); d++) {
162 str_stream << shape[d];
163 if (d + 1 != shape.size()) {
164 str_stream << ",";
165 }
166 }
167 if (shape_index + 1 != shapes.size()) {
168 str_stream << ";";
169 }
170 shape_index++;
171 }
172 return str_stream.str();
173 }
174
TransStringToInputShapes(const std::string & shapes_str)175 std::map<std::string, std::vector<int64_t>> TransStringToInputShapes(const std::string &shapes_str) {
176 std::map<std::string, std::vector<int64_t>> shapes;
177 auto shapes_pairs = lite::SplitStringToVector(shapes_str, ';');
178 for (auto &kv_str : shapes_pairs) {
179 auto pos = kv_str.rfind(':');
180 if (pos == std::string::npos || pos + 1 == kv_str.size()) {
181 MS_LOG(ERROR) << "Invalid input shapes string: " << shapes_str;
182 return {};
183 }
184 auto name = kv_str.substr(0, pos);
185 auto shape_str = kv_str.substr(pos + 1);
186 auto shape_dims_str = lite::SplitStringToVector(shape_str, ',');
187 std::vector<int64_t> shape;
188 shape.reserve(shape_dims_str.size());
189 for (auto &dim_str : shape_dims_str) {
190 int dim = 0;
191 if (!lite::ConvertIntNum(dim_str, &dim)) {
192 MS_LOG(ERROR) << "Invalid input shapes string: " << shapes_str;
193 return {};
194 }
195 shape.push_back(dim);
196 }
197 shapes[name] = shape;
198 }
199 return shapes;
200 }
201 } // namespace
202
// Constructor is defaulted out of line; no construction logic lives in this file.
AnfTransform::AnfTransform() = default;

// Destructor is defaulted out of line.
AnfTransform::~AnfTransform() = default;
206
MarkTrainInputOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)207 STATUS AnfTransform::MarkTrainInputOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
208 for (size_t i = 1; i < cnode->size(); i++) {
209 auto input_node = cnode->input(i);
210 if (!utils::isa<CNodePtr>(input_node)) {
211 continue;
212 }
213 auto input_cnode = utils::cast<CNodePtr>(input_node);
214 MS_CHECK_TRUE_RET(input_cnode != nullptr, RET_ERROR);
215 auto prim = GetValueNode<PrimitivePtr>(input_cnode->input(0));
216 if (prim == nullptr) {
217 MS_LOG(DEBUG) << "Primitive is nullptr.";
218 continue;
219 }
220 (void)prim->AddAttr("trainOp", MakeValue(true));
221 }
222 return RET_OK;
223 }
224
MarkTrainWeightSharingOp(const FuncGraphPtr & func_graph,const CNodePtr & cnode)225 STATUS AnfTransform::MarkTrainWeightSharingOp(const FuncGraphPtr &func_graph, const CNodePtr &cnode) {
226 auto node_list = TopoSort(func_graph->get_return());
227 for (auto &node : node_list) {
228 if (!utils::isa<CNodePtr>(node)) {
229 continue;
230 }
231 auto graph_cnode = utils::cast<CNodePtr>(node);
232 MS_CHECK_TRUE_RET(graph_cnode != nullptr, RET_ERROR);
233 auto graph_prim = GetValueNode<PrimitivePtr>(graph_cnode->input(0));
234 if (graph_prim == nullptr) {
235 MS_LOG(DEBUG) << "Primitive is nullptr.";
236 continue;
237 }
238 for (size_t i = 1; i < graph_cnode->size(); i++) {
239 for (size_t j = 1; j < cnode->size(); j++) {
240 if ((graph_cnode->input(i) == cnode->input(j)) && utils::isa<Parameter>(cnode->input(j))) {
241 (void)graph_prim->AddAttr("trainOp", MakeValue(true));
242 }
243 }
244 }
245 }
246 return RET_OK;
247 }
248
MarkTrainOp(const FuncGraphPtr & func_graph)249 STATUS AnfTransform::MarkTrainOp(const FuncGraphPtr &func_graph) {
250 auto node_list = TopoSort(func_graph->get_return());
251 for (auto &node : node_list) {
252 if (!utils::isa<CNodePtr>(node)) {
253 continue;
254 }
255 auto cnode = utils::cast<CNodePtr>(node);
256 MS_CHECK_TRUE_RET(cnode != nullptr, RET_ERROR);
257 auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
258 if (prim == nullptr) {
259 MS_LOG(DEBUG) << "Primitive is nullptr.";
260 continue;
261 }
262 if (opt::IsTrainOp(cnode)) {
263 (void)prim->AddAttr("trainOp", MakeValue(true));
264 auto status = MarkTrainInputOp(func_graph, cnode);
265 if (status != RET_OK) {
266 MS_LOG(ERROR) << "MarkTrainInputOp failed.";
267 return RET_ERROR;
268 }
269 status = MarkTrainWeightSharingOp(func_graph, cnode);
270 if (status != RET_OK) {
271 MS_LOG(ERROR) << "MarkTrainWeightSharingOp failed.";
272 return RET_ERROR;
273 }
274 }
275 }
276 return RET_OK;
277 }
InitFusions(const std::shared_ptr<ConverterPara> & param)278 std::vector<opt::PassPtr> InitFusions(const std::shared_ptr<ConverterPara> ¶m) {
279 // The training model only does the fusion of the inference part
280 // remove quantdtype when awaretraining
281 std::vector<opt::PassPtr> fusions{std::make_shared<opt::AddConcatActivationFusion>(),
282 std::make_shared<opt::HardSwishFusion>(),
283 std::make_shared<opt::PReluFusion>(),
284 std::make_shared<opt::SqueezeFusion>(),
285 std::make_shared<opt::TransposeFusion>(),
286 std::make_shared<opt::CastFusionPass>(),
287 std::make_shared<opt::ReshapeReshapeFusion>(),
288 std::make_shared<opt::ReshapeTransposeFusion>(),
289 std::make_shared<opt::ConvBiasaddFusion>(),
290 std::make_shared<opt::ConvBatchNormFusion>(param->fmk_type),
291 std::make_shared<opt::ConvScaleFusion>(param->fmk_type),
292 std::make_shared<opt::GroupNormFusion>(),
293 std::make_shared<opt::TfNormFusion>(),
294 std::make_shared<opt::OnnxLayerNormFusion>(),
295 std::make_shared<opt::OnnxLayerNormFusion2>(),
296 std::make_shared<opt::BatchMatMulFusion>(),
297 std::make_shared<opt::BatchNormToScaleFusion>(),
298 std::make_shared<opt::SigmoidMulFusion>(),
299 std::make_shared<opt::ActivationFusion>(),
300 std::make_shared<opt::ConvActivationFusion>(param),
301 std::make_shared<opt::ConvTupleGetItemFusion>(),
302 std::make_shared<opt::ConvTupleActivationFusion>(),
303 std::make_shared<opt::TfliteLstmCellFusion>(),
304 std::make_shared<opt::TfLstmCellFusion>(),
305 std::make_shared<opt::TfBidirectionGruFusion>(),
306 std::make_shared<opt::TfGeLUFusion>(),
307 std::make_shared<opt::OnnxGeLUFusion>(),
308 std::make_shared<opt::TfliteRelPosMultiHeadAttentionFusion>(),
309 std::make_shared<opt::GLUFusion>(),
310 std::make_shared<opt::ResizeFusion1>(),
311 std::make_shared<opt::ResizeFusion2>(),
312 std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model),
313 std::make_shared<opt::AffineFusion>(),
314 std::make_shared<opt::AffineActivationFusion>(),
315 std::make_shared<opt::ConvConvFusion>(),
316 std::make_shared<opt::ConvPadFusion>(),
317 std::make_shared<opt::MatMulAddFusion>(),
318 std::make_shared<opt::MatMulMulFusion>(),
319 std::make_shared<opt::TransposeMatMulFusion>(),
320 std::make_shared<opt::MulAddFusion>(),
321 std::make_shared<opt::ScaleActivationFusion>(),
322 std::make_shared<opt::ScaleScaleFusion>(),
323 std::make_shared<opt::FullConnectedFusion>(),
324 std::make_shared<opt::FullconnectedAddFusion>(),
325 std::make_shared<opt::TensorDotFusion>(),
326 std::make_shared<opt::MatMulActivationFusion>(param),
327 std::make_shared<opt::MulActivationFusion>(),
328 std::make_shared<opt::AddActivationFusion>(),
329 std::make_shared<opt::ExpandDimsReshapeFusion>(),
330 std::make_shared<opt::SqueezeExpandDimsFusion>(),
331 std::make_shared<opt::TileMatMulFusion>()};
332 if (param->optimize_transformer) {
333 fusions.push_back(std::make_shared<opt::MultiHeadAttentionFusion>());
334 fusions.push_back(std::make_shared<opt::EncoderLayerFusion>(true));
335 fusions.push_back(std::make_shared<opt::EncoderLayerFusion>(false));
336 fusions.push_back(std::make_shared<opt::DecoderLayerFusion>());
337 }
338 return fusions;
339 }
340
RunFusionPass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)341 int AnfTransform::RunFusionPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
342 auto status = MarkTrainOp(old_graph);
343 if (status != RET_OK) {
344 MS_LOG(ERROR) << "MarkTrainOp failed.";
345 return RET_ERROR;
346 }
347 auto optimizer = std::make_shared<opt::GraphOptimizer>();
348 CHECK_NULL_RETURN(optimizer);
349 auto fusion_pm = std::make_shared<opt::LitePassManager>("anf fusion pass manager", false);
350 CHECK_NULL_RETURN(fusion_pm);
351
352 auto fusions = InitFusions(param);
353 for (size_t index = 0; index < fusions.size(); index++) {
354 auto pass_ptr = fusions.at(index);
355 MS_CHECK_TRUE_RET(pass_ptr != nullptr, RET_ERROR);
356 auto pass_name = pass_ptr->name();
357 if (param->fusion_blacklists.find(pass_name) != param->fusion_blacklists.end()) {
358 MS_LOG(INFO) << "Disable fusion: " << pass_name;
359 continue;
360 }
361 fusion_pm->AddPass(pass_ptr);
362 }
363 optimizer->AddPassManager(fusion_pm);
364 if (optimizer->Optimize(old_graph) == nullptr) {
365 MS_LOG(ERROR) << "run op fusion failed.";
366 return RET_ERROR;
367 }
368
369 // the following pass needs to check the return value.
370 fusions = {std::make_shared<opt::ReduceSameOpInHorizon>(param), std::make_shared<opt::ReshapeReduceFusion>(),
371 std::make_shared<opt::AblateReshapeLikeOp>(), std::make_shared<opt::MulReduceFusion>(),
372 std::make_shared<opt::ConcatConcatFusion>(), std::make_shared<opt::ReduceStackFusion>(),
373 std::make_shared<opt::RemoveTransitivityOp>(), std::make_shared<opt::StridedSliceFusion>(),
374 std::make_shared<opt::RemoveTransitivityOp>(), std::make_shared<opt::ReshapeShapeFusion>(),
375 std::make_shared<opt::TransposeGatherFusion>()};
376 for (auto &pass : fusions) {
377 MS_CHECK_TRUE_MSG(pass != nullptr, RET_ERROR, "pass is a nullptr.");
378 if (param->fusion_blacklists.find(pass->name()) != param->fusion_blacklists.end()) {
379 MS_LOG(INFO) << "Disable fusion: " << pass->name();
380 continue;
381 }
382 if (!pass->Run(old_graph)) {
383 MS_LOG(ERROR) << pass->name() << " running failed.";
384 return RET_ERROR;
385 }
386 }
387 return RET_OK;
388 }
389
RunParallelPass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)390 int AnfTransform::RunParallelPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
391 MS_LOG(DEBUG) << "Run ParallelPass start";
392 if (param->train_model || param->parallel_split_config.parallel_split_type_ == SplitNo) {
393 return RET_OK;
394 }
395 if (param->parallel_split_config.parallel_split_type_ == SplitByUserRatio) {
396 auto optimizer = std::make_shared<opt::GraphOptimizer>();
397 CHECK_NULL_RETURN(optimizer);
398 auto graph_inputs = old_graph->get_inputs();
399 opt::SplitMode split_mode = opt::NoSplit;
400 for (const auto &graph_input : graph_inputs) {
401 if (utils::isa<Parameter>(graph_input)) {
402 auto input_parameter = dyn_cast<Parameter>(graph_input);
403 MSLITE_CHECK_PTR(input_parameter->Shape());
404 auto shape_ptr = input_parameter->Shape()->cast<abstract::ShapePtr>();
405 MSLITE_CHECK_PTR(shape_ptr);
406 auto batch = shape_ptr->shape().front();
407 if (batch > opt::kDefaultBatch) {
408 split_mode = opt::SplitN;
409 } else {
410 split_mode = opt::SplitH;
411 }
412 break;
413 }
414 }
415 // 1. deal with split strategy
416 std::unordered_map<std::string, opt::SplitStrategy> split_strategys = opt::ParserSplitStrategy(
417 param->parallel_split_config.parallel_compute_rates_, param->parallel_split_config.parallel_devices_, split_mode);
418 if (split_strategys.empty()) {
419 MS_LOG(WARNING) << "No valid split_strategy. Run convert without split";
420 return RET_OK;
421 }
422 opt::Spliter::GetInstance()->RecordGraphInfo(old_graph);
423 auto parallel_pm = std::make_shared<opt::LitePassManager>("anf parallel pass manager", true);
424 CHECK_NULL_RETURN(parallel_pm);
425 // 2. preceding parallel pass
426 parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
427 parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
428 std::set<int, opt::IntCompare> match_multi_numbers = opt::Spliter::GetInstance()->graph_match_multi_numbers();
429 int max_match_number = *match_multi_numbers.begin();
430 // we do not deal with single conv node
431 for (int match_number = max_match_number; match_number > opt::kDefaultBatch; --match_number) {
432 // 3. multi_conv parallel pass
433 parallel_pm->AddPass(std::make_shared<opt::MultiConvSplitPass>(split_strategys, param->fmk_type, match_number));
434 parallel_pm->AddPass(std::make_shared<opt::IterNodeOutputs>());
435 parallel_pm->AddPass(std::make_shared<opt::NodeOutShapes>());
436 }
437 optimizer->AddPassManager(parallel_pm);
438 if (optimizer->Optimize(old_graph) == nullptr) {
439 MS_LOG(ERROR) << "run const fold failed.";
440 return RET_ERROR;
441 }
442 }
443 MS_LOG(DEBUG) << "Run ParallelPass end";
444 return RET_OK;
445 }
446
RunGraphPass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)447 int AnfTransform::RunGraphPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
448 auto optimizer = std::make_shared<opt::GraphOptimizer>();
449 CHECK_NULL_RETURN(optimizer);
450 auto graph_pm = std::make_shared<opt::LitePassManager>("anf graph pass manager", true);
451 CHECK_NULL_RETURN(graph_pm);
452 if (param->fmk_type == converter::kFmkTypeTflite || param->fmk_type == converter::kFmkTypeTf ||
453 param->fmk_type == converter::kFmkTypeOnnx) {
454 graph_pm->AddPass(std::make_shared<opt::ControlFlowPass>());
455 }
456 auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
457 CHECK_NULL_RETURN(slice_prepose_pass);
458 slice_prepose_pass->SetFmkType(param->fmk_type);
459 graph_pm->AddPass(slice_prepose_pass);
460 optimizer->AddPassManager(graph_pm);
461 if (optimizer->Optimize(old_graph) == nullptr) {
462 MS_LOG(ERROR) << "run graph pass failed.";
463 return RET_ERROR;
464 }
465 return RET_OK;
466 }
467
RunConvertPass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)468 int AnfTransform::RunConvertPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
469 if (param->device.find("Ascend") != std::string::npos) {
470 auto acl_pass_ptr = opt::AclPassPlugin::CreateAclPass(param);
471 if (acl_pass_ptr == nullptr) {
472 MS_LOG(ERROR) << "Failed to create acl pass";
473 return RET_ERROR;
474 }
475 if (!acl_pass_ptr->Run(old_graph)) {
476 MS_LOG(ERROR) << "Acl pass failed.";
477 return RET_ERROR;
478 }
479 }
480 // adjust for conv2d_transpose
481 if (!(param->no_fusion && param->save_type == kMindIR)) {
482 std::set<FuncGraphPtr> all_func_graphs = {};
483 GetAllFuncGraph(old_graph, &all_func_graphs);
484 auto conv2d_transpose_adjust = std::make_shared<Conv2DTransposeInputAdjust>();
485 MS_CHECK_TRUE_MSG(conv2d_transpose_adjust != nullptr, RET_NULL_PTR, "conv2d_transpose_adjust is nullptr.");
486 for (auto sub_graph : all_func_graphs) {
487 if (!conv2d_transpose_adjust->Run(old_graph)) {
488 MS_LOG(ERROR) << "adjust conv2d_transpose failed";
489 return RET_ERROR;
490 }
491 }
492 }
493 auto optimizer = std::make_shared<opt::GraphOptimizer>();
494 CHECK_NULL_RETURN(optimizer);
495 auto convert_pm = std::make_shared<opt::LitePassManager>("anf graph convert pass manager", true);
496 CHECK_NULL_RETURN(convert_pm);
497 convert_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
498 convert_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
499 convert_pm->AddPass(std::make_shared<opt::CastOpAdjust>());
500 convert_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
501 optimizer->AddPassManager(convert_pm);
502 if (optimizer->Optimize(old_graph) == nullptr) {
503 MS_LOG(ERROR) << "run graph convert pass failed.";
504 return RET_ERROR;
505 }
506 return RET_OK;
507 }
508
RunConstFoldPass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)509 int AnfTransform::RunConstFoldPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
510 auto optimizer = std::make_shared<opt::GraphOptimizer>();
511 auto const_fold_pm = std::make_shared<opt::LitePassManager>("const fold fusion pass manager", false);
512 CHECK_NULL_RETURN(optimizer);
513 CHECK_NULL_RETURN(const_fold_pm);
514 if (param->train_model) {
515 const_fold_pm->AddPass(std::make_shared<opt::MiniaturizationPass>());
516 }
517 const_fold_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
518 if (!param->train_model) {
519 const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(param->fmk_type, param->train_model));
520 }
521 const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
522 const_fold_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>(), param->fusion_blacklists);
523 optimizer->AddPassManager(const_fold_pm);
524 if (optimizer->Optimize(old_graph) == nullptr) {
525 MS_LOG(ERROR) << "run const fold failed.";
526 return RET_ERROR;
527 }
528 return RET_OK;
529 }
530
RunInt64CastInt32Pass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)531 int AnfTransform::RunInt64CastInt32Pass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
532 auto optimizer = std::make_shared<opt::GraphOptimizer>();
533 auto int64_cast_int32_pm = std::make_shared<opt::LitePassManager>("int64 cast to int32 pass manager", false);
534 CHECK_NULL_RETURN(optimizer);
535 CHECK_NULL_RETURN(int64_cast_int32_pm);
536 int64_cast_int32_pm->AddPass(std::make_shared<opt::InferShapePass>(param->fmk_type, param->train_model));
537 int64_cast_int32_pm->AddPass(std::make_shared<opt::Int64CastInt32Pass>());
538 int64_cast_int32_pm->AddPass(std::make_shared<opt::CastFusionPass>());
539
540 optimizer->AddPassManager(int64_cast_int32_pm);
541 if (optimizer->Optimize(old_graph) == nullptr) {
542 MS_LOG(ERROR) << "run const fold failed.";
543 return RET_ERROR;
544 }
545 return RET_OK;
546 }
547
RunDecreaseTransposePass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)548 int RunDecreaseTransposePass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
549 MS_ASSERT(old_graph != nullptr && param != nullptr);
550 auto pass = std::make_shared<opt::DecreaseTransposeAlgo>(param->fmk_type, param->train_model, false);
551 MS_CHECK_TRUE_RET(pass != nullptr, RET_ERROR);
552 if (!pass->Run(old_graph)) {
553 MS_LOG(ERROR) << "Run DecreaseTransposeAlgo pass failed";
554 return RET_ERROR;
555 }
556
557 auto optimizer = std::make_shared<opt::GraphOptimizer>();
558 auto decrease_trans_pm = std::make_shared<opt::LitePassManager>("decrease transpose fusion pass manager", false);
559 CHECK_NULL_RETURN(optimizer);
560 CHECK_NULL_RETURN(decrease_trans_pm);
561 std::vector<opt::PassPtr> fusions = {std::make_shared<opt::ReshapeTransposeFusion>(),
562 std::make_shared<opt::TransposeFusion>()};
563 (void)std::for_each(fusions.begin(), fusions.end(), [&decrease_trans_pm, ¶m](opt::PassPtr fusion) {
564 if (fusion != nullptr && param->fusion_blacklists.find(fusion->name()) == param->fusion_blacklists.end()) {
565 decrease_trans_pm->AddPass(fusion);
566 }
567 });
568 optimizer->AddPassManager(decrease_trans_pm);
569 if (optimizer->Optimize(old_graph) == nullptr) {
570 MS_LOG(ERROR) << "run decrease transpose failed.";
571 return RET_ERROR;
572 }
573 return RET_OK;
574 }
575
CheckExternalExtension(const std::shared_ptr<ConverterPara> & param)576 bool AnfTransform::CheckExternalExtension(const std::shared_ptr<ConverterPara> ¶m) {
577 return (!param->plugins_path.empty() && param->commonQuantParam.quant_type != quant::QUANT_NONE);
578 }
579
DoQuantize(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)580 int AnfTransform::DoQuantize(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
581 quant::QuantizationOptimizer quantization_optimizer(param);
582 auto ret = quantization_optimizer.Run(old_graph);
583 if (ret != RET_OK) {
584 MS_LOG(ERROR) << "Post training quantization failed.";
585 return ret;
586 }
587 return RET_OK;
588 }
589
DoFormatForMindIR(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)590 int AnfTransform::DoFormatForMindIR(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
591 if (param->save_type != kMindIR) {
592 return RET_OK;
593 }
594 if (param->no_fusion || param->device.find("Ascend") == std::string::npos) {
595 MS_LOG(INFO) << "export MindIR, run pass ToNCHWFormat";
596 if (!RunOptimizerPass(old_graph, {"ToNCHWFormat", "DecreaseTransposeAlgo"})) {
597 MS_LOG(ERROR) << "Run ToNCHWFormat pass failed";
598 return RET_ERROR;
599 }
600 }
601 old_graph->set_attr(kOriginalFmkType, MakeValue(static_cast<int32_t>(param->fmk_type)));
602
603 return RET_OK;
604 }
605
RunFormatTrans(const FuncGraphPtr & old_graph)606 int AnfTransform::RunFormatTrans(const FuncGraphPtr &old_graph) {
607 auto value = old_graph->get_attr(ops::kFormat);
608 if (value != nullptr && GetValue<int64_t>(value) == mindspore::NHWC) {
609 return RET_OK;
610 }
611 if (!RunOptimizerPass(old_graph, {"ToNHWCFormat", "DecreaseTransposeAlgo"})) {
612 MS_LOG(ERROR) << "Run ToNHWCFormat pass failed";
613 return RET_ERROR;
614 }
615 return RET_OK;
616 }
617
RunEliminateRedundantPass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)618 bool RunEliminateRedundantPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
619 if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
620 MS_LOG(WARNING) << "Run infershape opt pass failed.";
621 } else if (!RunOptimizerPass(old_graph, {"DecreaseTransposeAlgo"})) {
622 MS_LOG(ERROR) << "Run transpose opt pass failed.";
623 return false;
624 }
625
626 auto optimizer = std::make_shared<opt::GraphOptimizer>();
627 MS_CHECK_TRUE_RET(optimizer != nullptr, false);
628 auto eliminate_pm = std::make_shared<opt::LitePassManager>("anf graph eliminate redundant pass manager", true);
629 MS_CHECK_TRUE_RET(eliminate_pm != nullptr, false);
630 eliminate_pm->AddPass(std::make_shared<opt::RemoveRedundantOpPass>(param->train_model));
631 eliminate_pm->AddPass(std::make_shared<opt::EliminateRedundantCastPass>(param->fmk_type, param->train_model));
632 eliminate_pm->AddPass(std::make_shared<opt::ReduceSameActPass>());
633 eliminate_pm->AddPass(std::make_shared<opt::SplitOnePass>());
634 eliminate_pm->AddPass(std::make_shared<opt::MulConstantPass>());
635 optimizer->AddPassManager(eliminate_pm);
636 if (optimizer->Optimize(old_graph) == nullptr) {
637 MS_LOG(ERROR) << "run graph convert pass failed.";
638 return false;
639 }
640 return true;
641 }
642
ProcOnlineTransform(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)643 STATUS AnfTransform::ProcOnlineTransform(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
644 if (!RunOptimizerPass(old_graph, {"RemoveRedundantOpPass", "InferShapePass", "ConstFoldPass"})) {
645 MS_LOG(WARNING) << "Run infershape opt pass failed.";
646 }
647 auto status = DoFormatForMindIR(old_graph, param);
648 if (status != RET_OK) {
649 MS_LOG(ERROR) << "Do format for mindir failed.";
650 return lite::RET_ERROR;
651 }
652 if (!param->input_shape.empty()) {
653 auto graph_inputs = old_graph->get_inputs();
654 std::map<std::string, std::vector<int64_t>> input_shape;
655 for (auto &input : graph_inputs) {
656 auto abstract = input->abstract();
657 if (abstract) {
658 input_shape[abstract->name()] = opt::GetAnfNodeOutputShape(input, 0);
659 }
660 }
661 old_graph->set_attr(kConverterInputShape, MakeValue(TransInputShapesToString(input_shape)));
662 }
663 return lite::RET_OK;
664 }
665
RunPass(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)666 int AnfTransform::RunPass(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
667 auto status = RunConvertPass(old_graph, param);
668 if (status != RET_OK) {
669 MS_LOG(ERROR) << "Run convert pass failed.";
670 return RET_ERROR;
671 }
672 if (!RunExternalPass(old_graph, registry::POSITION_BEGIN)) {
673 MS_LOG(ERROR) << "Run external pass failed, place is BEGIN";
674 return RET_ERROR;
675 }
676
677 status = RunInt64CastInt32Pass(old_graph, param);
678 if (status != RET_OK) {
679 MS_LOG(ERROR) << "RunInt64CastInt32Pass failed.";
680 return RET_ERROR;
681 }
682 status = RunConstFoldPass(old_graph, param);
683 if (status != RET_OK) {
684 MS_LOG(ERROR) << "Run const fold pass failed.";
685 return RET_ERROR;
686 }
687
688 if (!RunEliminateRedundantPass(old_graph, param)) {
689 MS_LOG(ERROR) << "Run elimination of redundant pass failed.";
690 return RET_ERROR;
691 }
692
693 if (!param->no_fusion) {
694 status = RunFusionPass(old_graph, param);
695 if (status != RET_OK) {
696 MS_LOG(ERROR) << "Run fusion pass failed.";
697 return RET_ERROR;
698 }
699 }
700
701 if (!RunExternalPass(old_graph, registry::POSITION_END)) {
702 MS_LOG(ERROR) << "Run external pass failed, place is END";
703 return RET_ERROR;
704 }
705
706 if (!RunOptimizerPass(old_graph, {"InferShapePass"})) {
707 MS_LOG(WARNING) << "Run infershape opt pass failed.";
708 status = RunOptimizerPass(old_graph, {"SpecialNodePostProcess"}) ? RET_OK : RET_ERROR;
709 } else {
710 status =
711 RunOptimizerPass(old_graph, {"SpecialNodePostProcess"}) ? RunDecreaseTransposePass(old_graph, param) : RET_ERROR;
712 }
713 if (status != RET_OK) {
714 MS_LOG(ERROR) << "Run transpose opt pass failed.";
715 return RET_ERROR;
716 }
717
718 if (CheckExternalExtension(param)) {
719 MS_LOG(ERROR) << "Unsupported external extension with quantization.";
720 return RET_ERROR;
721 }
722 // QATTransform will infer all subgraphs and should be executed before ControlFlowPass.
723 // After ControlFlowPass, there will be some ops that cannot be handled in the main graph, therefore, The
724 // InferShapePass cannot be executed.
725 auto qat_transform = quant::QATTransform(old_graph, param);
726 status = qat_transform.Transform();
727 if (status != RET_OK) {
728 MS_LOG(ERROR) << "Do QATTransform failed.";
729 return RET_ERROR;
730 }
731
732 status = RunGraphPass(old_graph, param);
733 if (status != RET_OK) {
734 MS_LOG(ERROR) << "Run convert pass failed.";
735 return RET_ERROR;
736 }
737
738 status = RunParallelPass(old_graph, param);
739 if (status != RET_OK) {
740 MS_LOG(ERROR) << "Run convert pass failed.";
741 return RET_ERROR;
742 }
743 return RET_OK;
744 }
745
TransformFuncGraph(const FuncGraphPtr & old_graph,const std::shared_ptr<ConverterPara> & param)746 STATUS AnfTransform::TransformFuncGraph(const FuncGraphPtr &old_graph, const std::shared_ptr<ConverterPara> ¶m) {
747 MS_ASSERT(old_graph != nullptr);
748 MS_ASSERT(param != nullptr);
749 if (param->no_fusion && param->save_type == kMindIR) { // converter, online
750 if (ProcOnlineTransform(old_graph, param) != lite::RET_OK) {
751 MS_LOG(ERROR) << "Proc online transform failed.";
752 return RET_ERROR;
753 }
754 auto status = DoQuantize(old_graph, param);
755 if (status != RET_OK) {
756 MS_LOG(ERROR) << "Do Quantize failed.";
757 return RET_ERROR;
758 }
759 return RET_OK;
760 }
761 auto value = old_graph->get_attr(kIsOptimized);
762 if (param->is_runtime_converter) { // load online
763 if (value != nullptr) { // other models converted to MindIR
764 auto status = RunFormatTrans(old_graph);
765 if (status != RET_OK) {
766 MS_LOG(ERROR) << "Run format trans failed";
767 return status;
768 }
769 }
770 }
771
772 if (RunPass(old_graph, param) != RET_OK) {
773 MS_LOG(ERROR) << "Proc online transform failed.";
774 return RET_ERROR;
775 }
776
777 auto status = DoQuantize(old_graph, param);
778 if (status != RET_OK) {
779 MS_LOG(ERROR) << "Do Quantize failed.";
780 return RET_ERROR;
781 }
782
783 #ifdef MSLITE_ENABLE_GRAPH_KERNEL
784 if (param->device.find("Ascend") == std::string::npos) {
785 if (GraphKernelOptimize(old_graph, param) != RET_OK) {
786 MS_LOG(ERROR) << "Do graphkernel optimization failed.";
787 return RET_ERROR;
788 }
789 }
790 #endif
791
792 status = DoFormatForMindIR(old_graph, param);
793 if (status != RET_OK) {
794 return RET_ERROR;
795 }
796 return RET_OK;
797 }
798
StoreBuiltinPass(const std::shared_ptr<ConverterPara> & param)799 bool AnfTransform::StoreBuiltinPass(const std::shared_ptr<ConverterPara> ¶m) {
800 if (param == nullptr) {
801 MS_LOG(ERROR) << "param is nullptr";
802 return false;
803 }
804 auto fmk = param->fmk_type;
805 auto is_train = param->train_model;
806
807 // pass_name, pass and boolean value to indicate whether can be called by external extension,
808 std::vector<std::tuple<std::string, opt::PassPtr, bool>> pass_infos = {
809 {"DumpGraph", std::make_shared<opt::DumpGraph>(param), true},
810 {"RemoveRedundantOpPass", std::make_shared<opt::RemoveRedundantOpPass>(param->train_model), false},
811 {"ToNCHWFormat", std::make_shared<opt::ToNCHWFormat>(fmk, is_train, param->save_type), true},
812 {"ToNHWCFormat", std::make_shared<opt::ToNHWCFormat>(fmk, is_train, param->save_type), true},
813 {"ConstFoldPass", std::make_shared<opt::ConstFoldPass>(fmk, is_train), true},
814 {"InferShapePass", std::make_shared<opt::InferShapePass>(fmk, is_train), false},
815 {"DeleteRedundantTranspose", std::make_shared<opt::DeleteRedundantTranspose>(), false},
816 {"SpecialNodePostProcess", std::make_shared<opt::SpecialNodePostProcess>(), false},
817 {"DecreaseTransposeAlgo", std::make_shared<opt::DecreaseTransposeAlgo>(fmk, is_train), true},
818 {"RemoveUnusedAddNodePass", std::make_shared<opt::RemoveUnusedAddNodePass>(), false},
819 {"AdjustResizeDimsPass", std::make_shared<opt::AdjustResizeDimsPass>(), false},
820 {"ScalarOpPass", std::make_shared<opt::ScalarOpPass>(), true},
821 {"FlashAttentionFusionForCustom",
822 std::make_shared<opt::FlashAttentionFusionForCustom>(param->aclModelOptionCfgParam.plugin_custom_ops,
823 param->aclModelOptionCfgParam.enable_custom_fusion_pattern,
824 param->aclModelOptionCfgParam.disable_custom_fusion_pattern),
825 false},
826 {"MakeListPass", std::make_shared<opt::MakeListPass>(), true},
827 {"FlashAttentionFusion", std::make_shared<opt::FlashAttentionFusion>(param->aclModelOptionCfgParam.op_attrs_map),
828 false},
829 {"GroupNormSiluFusion", std::make_shared<opt::GroupNormSiluFusion>(), false},
830 {"GeGluV2Fusion", std::make_shared<opt::GeGluV2Fusion>(), false},
831 {"LayerNormV3Fusion", std::make_shared<opt::LayerNormV3Fusion>(), false},
832 {"FFNFusion", std::make_shared<opt::FFNFusion>(), false},
833 {"FuseAddAndLayernorm", std::make_shared<opt::FuseAddAndLayernorm>(), false},
834 {"AdjustMatmulPass", std::make_shared<opt::AdjustMatmulPass>(), false}};
835 for (const auto &pass_info : pass_infos) {
836 MS_CHECK_TRUE_RET(std::get<1>(pass_info) != nullptr, false);
837 PassStorage::StorePass(std::get<0>(pass_info), std::get<1>(pass_info), std::get<opt::kInputIndexTwo>(pass_info));
838 }
839 auto dump_graph_outer = std::make_shared<opt::DumpGraph>(param);
840 MS_CHECK_TRUE_MSG(dump_graph_outer != nullptr, false, "dumpGraph object is a nullptr.");
841 registry::PassRegistry("DumpGraph", dump_graph_outer);
842 return true;
843 }
844
ClearBuiltinPass()845 void AnfTransform::ClearBuiltinPass() { PassStorage::ClearPass(); }
846
Transform(const FuncGraphPtr & main_graph,const std::shared_ptr<ConverterPara> & param)847 STATUS AnfTransform::Transform(const FuncGraphPtr &main_graph, const std::shared_ptr<ConverterPara> ¶m) {
848 MS_CHECK_TRUE_MSG(main_graph != nullptr, RET_NULL_PTR, "Input func_graph is nullptr");
849 MS_CHECK_TRUE_MSG(param != nullptr, RET_NULL_PTR, "Input converter param is nullptr");
850 manager_ = Manage(main_graph, true);
851
852 if (main_graph->has_attr(kOriginalFmkType)) {
853 auto val_ptr = main_graph->get_attr(kOriginalFmkType);
854 MS_CHECK_TRUE_MSG(val_ptr != nullptr, RET_NULL_PTR, "Val ptr is nullptr.");
855 param->fmk_type = static_cast<converter::FmkType>(GetValue<int32_t>(val_ptr));
856 }
857 if (main_graph->has_attr(kConverterInputShape)) {
858 auto val_ptr = main_graph->get_attr(kConverterInputShape);
859 MS_CHECK_TRUE_MSG(val_ptr != nullptr, RET_NULL_PTR, "Val ptr is nullptr.");
860 param->input_shape = TransStringToInputShapes(GetValue<std::string>(val_ptr));
861 for (auto &kv : param->input_shape) {
862 lite::ConverterInnerContext::GetInstance()->UpdateGraphInputTensorShape(kv.first, kv.second);
863 }
864 }
865 if (!StoreBuiltinPass(param)) {
866 MS_LOG(ERROR) << "store pass failed.";
867 return RET_ERROR;
868 }
869
870 auto status = TransformFuncGraph(main_graph, param);
871 ClearBuiltinPass();
872 if (status != RET_OK) {
873 MS_LOG(ERROR) << "optimizer failed.";
874 return RET_NULL_PTR;
875 }
876
877 return RET_OK;
878 }
879 } // namespace mindspore::lite
880