1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/ps/pass.h"
18
19 #include <memory>
20 #include <vector>
21 #include <string>
22 #include <algorithm>
23
24 #include "mindspore/core/ops/other_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "utils/hash_map.h"
27 #include "ir/func_graph_cloner.h"
28 #include "pipeline/jit/ps/parse/parse_base.h"
29 #include "pipeline/jit/ps/resource.h"
30 #include "pipeline/jit/ps/validator.h"
31 #include "pipeline/jit/ps/remove_value_node_dup.h"
32 #include "frontend/optimizer/opt.h"
33 #include "frontend/optimizer/optimizer.h"
34 #include "frontend/optimizer/cse_pass.h"
35 #include "frontend/optimizer/fallback_rewriter.h"
36 #include "frontend/optimizer/irpass.h"
37 #include "frontend/optimizer/graph_transform.h"
38 #include "frontend/optimizer/auto_monad_eliminate.h"
39 #include "include/common/fallback.h"
40 #include "include/common/utils/parallel_context.h"
41 #include "frontend/parallel/dynamic_shape/dynamic_shape.h"
42 #include "frontend/parallel/step_parallel.h"
43 #include "frontend/parallel/step_auto_parallel.h"
44 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
45 #include "frontend/parallel/pipeline_transformer/pipeline_scheduler.h"
46 #include "frontend/parallel/pipeline_transformer/pipeline_interleave.h"
47 #include "frontend/parallel/pipeline_transformer/gpipe_interleave_scheduler.h"
48 #include "frontend/parallel/pass/merge_comm.h"
49 #include "frontend/parallel/cache_embedding/cache_embedding.h"
50 #include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
51 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
52 #include "frontend/parallel/shard/shard.h"
53 #include "frontend/parallel/pass/optimize_parallel_allgather_comm.h"
54 #include "frontend/parallel/pass/label_micro_interleaved_index.h"
55 #include "frontend/parallel/pass/label_fine_grained_interleaved_index.h"
56 #include "frontend/parallel/pass/reorder_send_recv_between_fp_bp.h"
57 #include "frontend/parallel/pass/micro_interleaved_order_control.h"
58 #include "frontend/parallel/pass/full_micro_interleaved_order_control.h"
59 #include "frontend/parallel/pass/overlap_recompute_allgather_and_flashattention_grad.h"
60 #include "frontend/parallel/pass/assign_add_opt.h"
61 #include "frontend/parallel/pass/float32_redistribution.h"
62 #include "frontend/parallel/pass/swap_dp_allreduce_reducescatter.h"
63 #include "frontend/parallel/pass/merge_cast_opt.h"
64 #include "frontend/parallel/pass/remove_cast_before_assign_add.h"
65 #include "frontend/parallel/pass/bias_add_comm_swap.h"
66 #include "frontend/parallel/pass/matmul_add_comm_reduction.h"
67 #include "frontend/parallel/pass/comp_comm_scheduling.h"
68 #include "frontend/parallel/pass/overlap_opt_shard_in_pipeline.h"
69 #include "frontend/parallel/pass/slice_activation_in_cell_share_recompute.h"
70 #include "frontend/parallel/pass/handle_group_info.h"
71 #include "frontend/parallel/pass/overlap_recompute_and_grad_model_parallel.h"
72 #include "frontend/parallel/pass/overlap_gradmatmul_and_gradallreduce.h"
73 #include "frontend/parallel/pass/begin_end_overlap_inline.h"
74 #include "frontend/parallel/pass/split_matmul_comm_elementwise_fp.h"
75 #include "frontend/parallel/pass/split_layernorm_comm_fp.h"
76 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
77 #include "frontend/parallel/pass/overlap_grad_comm.h"
78 #include "frontend/optimizer/recompute.h"
79 #include "frontend/optimizer/irpass/recompute.h"
80 #include "frontend/optimizer/slice_activation_in_recompute.h"
81 #include "frontend/optimizer/grouped_pairwise_exchange_alltoall.h"
82 #include "frontend/optimizer/comm_op_attrs.h"
83 #include "frontend/optimizer/process_send_recv_for_ge.h"
84 #include "frontend/optimizer/environ_conversion.h"
85 #include "frontend/optimizer/comm_op_reuse_tag.h"
86 #include "frontend/optimizer/py_interpret_to_execute.h"
87 #include "frontend/optimizer/flash_sp.h"
88 #include "utils/log_adapter.h"
89 #include "utils/compile_config.h"
90 #include "pipeline/jit/ps/pipeline_split.h"
91 #include "pipeline/pynative/pynative_execute.h"
92 #include "pipeline/jit/ps/static_analysis/auto_monad.h"
93 #include "frontend/optimizer/irpass/branch_culling.h"
94 #include "frontend/optimizer/irpass/meta_fg_eliminate.h"
95 #include "frontend/optimizer/irpass/gradient_eliminate.h"
96 #include "frontend/optimizer/irpass/shard_eliminate.h"
97 #include "frontend/optimizer/irpass/taylor_eliminate.h"
98 #include "frontend/optimizer/irpass/parameter_eliminate.h"
99 #include "frontend/optimizer/irpass/updatestate_eliminate.h"
100 #include "frontend/optimizer/irpass/expand_dump_flag.h"
101 #include "frontend/optimizer/irpass/symbol_engine_optimizer.h"
102 #include "frontend/optimizer/irpass/add_forward_monad_depend.h"
103 #if defined(__linux__) && defined(WITH_BACKEND)
104 #include "include/backend/distributed/ps/util.h"
105 #include "include/backend/distributed/ps/ps_context.h"
106 #endif
107
108 namespace mindspore {
109 namespace pipeline {
110 using OptPassGroupMap = opt::OptPassGroupMap;
111 using Optimizer = opt::Optimizer;
112 using CompileGraphs = compile::CompileGraphs;
113 using abstract::AnalysisResult;
114 using mindspore::abstract::AnalysisContextPtr;
115 using mindspore::validator::Validate;
UpdateArgsSpec(const FuncGraphPtr & func_graph,const ResourcePtr & resource)116 void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource) {
117 MS_EXCEPTION_IF_NULL(func_graph);
118 MS_EXCEPTION_IF_NULL(resource);
119 abstract::AbstractBasePtrList args_abs;
120 const auto ¶meters = func_graph->parameters();
121 args_abs.reserve(parameters.size());
122 (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
123 [](const AnfNodePtr &p) { return p->abstract(); });
124 resource->set_args_abs(args_abs);
125 }
126
PyInterpretToExecutePass(const ResourcePtr & resource)127 bool PyInterpretToExecutePass(const ResourcePtr &resource) {
128 const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
129 if (!allow_fallback_runtime) {
130 return true;
131 }
132 MS_EXCEPTION_IF_NULL(resource);
133 FuncGraphPtr func_graph = resource->func_graph();
134 MS_EXCEPTION_IF_NULL(func_graph);
135 (void)opt::PyInterpretToExecute(resource);
136 UpdateArgsSpec(func_graph, resource);
137 return true;
138 }
139
RewriterBeforeOptAPass(const ResourcePtr & resource)140 bool RewriterBeforeOptAPass(const ResourcePtr &resource) {
141 MS_EXCEPTION_IF_NULL(resource);
142 FuncGraphPtr func_graph = resource->func_graph();
143 MS_EXCEPTION_IF_NULL(func_graph);
144 (void)opt::RewriterBeforeOptA(func_graph, resource->manager());
145 UpdateArgsSpec(func_graph, resource);
146 return true;
147 }
148
TransformTopGraphPass(const ResourcePtr & resource)149 bool TransformTopGraphPass(const ResourcePtr &resource) {
150 MS_EXCEPTION_IF_NULL(resource);
151 if (resource->func_graph() == nullptr) {
152 MS_LOG(INTERNAL_EXCEPTION) << "Transform top graph error.";
153 }
154 FuncGraphPtr func_graph = resource->func_graph();
155 if (opt::FuncGraphHasSequenceInput(func_graph)) {
156 opt::GraphSequenceParamTransform graph_trans;
157 func_graph = graph_trans(func_graph, resource->manager());
158 resource->set_func_graph(func_graph);
159 AbstractBasePtrList abs_spec_list;
160 auto ¶ms = func_graph->parameters();
161 (void)std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
162 [](const AnfNodePtr &node) { return node->abstract(); });
163 resource->set_args_abs(abs_spec_list);
164 }
165 return true;
166 }
167
RewriterAfterOptAPass(const ResourcePtr & resource)168 bool RewriterAfterOptAPass(const ResourcePtr &resource) {
169 MS_EXCEPTION_IF_NULL(resource);
170 FuncGraphPtr func_graph = resource->func_graph();
171 MS_EXCEPTION_IF_NULL(func_graph);
172 (void)opt::RewriterAfterOptA(func_graph, resource);
173 UpdateArgsSpec(func_graph, resource);
174 return true;
175 }
176
ConvertAfterRewriterPass(const ResourcePtr & resource)177 bool ConvertAfterRewriterPass(const ResourcePtr &resource) {
178 MS_EXCEPTION_IF_NULL(resource);
179 FuncGraphPtr func_graph = resource->func_graph();
180 MS_EXCEPTION_IF_NULL(func_graph);
181 (void)opt::ConvertAfterRewriter(func_graph, resource);
182 UpdateArgsSpec(func_graph, resource);
183 return true;
184 }
185
OrderPyExecuteAfterRewriterPass(const ResourcePtr & resource)186 bool OrderPyExecuteAfterRewriterPass(const ResourcePtr &resource) {
187 MS_EXCEPTION_IF_NULL(resource);
188 FuncGraphPtr func_graph = resource->func_graph();
189 MS_EXCEPTION_IF_NULL(func_graph);
190 (void)opt::OrderPyExecuteAfterRewriter(func_graph, resource);
191 UpdateArgsSpec(func_graph, resource);
192 return true;
193 }
194
PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib & irpass,const ResourcePtr & resource)195 FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource) {
196 MS_EXCEPTION_IF_NULL(resource);
197 MS_EXCEPTION_IF_NULL(resource->func_graph());
198 opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
199 irpass.pynative_eliminate_,
200 });
201
202 opt::OptPassConfig switch_simplify = opt::OptPassConfig({
203 irpass.switch_simplify_,
204 });
205
206 opt::OptPassConfig inline_opt = opt::OptPassConfig({
207 irpass.inline_,
208 });
209
210 OptPassGroupMap map(
211 {{"ad_eliminate", pynative_eliminate}, {"ad_inline", inline_opt}, {"ad_switch_simplify", switch_simplify}});
212
213 auto prim_bprop_opt_step_1 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_1", resource, map);
214 FuncGraphPtr func_graph = resource->func_graph();
215 ProfileExecute(MsProfile::GetProfile()->Step("prim_bprop_opt_step_1"), [&prim_bprop_opt_step_1, &func_graph]() {
216 func_graph = prim_bprop_opt_step_1->step(func_graph, true);
217 });
218 return func_graph;
219 }
220
// Stage-2 optimization of a primitive bprop graph (PyNative): renormalize,
// inline, simplify special ops, re-run auto-monad, and — when grad flags are
// provided — drop gradient computation for inputs that do not need grad.
// The step is wrapped in a profiler scope named after the optimizer.
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource,
                                const std::vector<bool> &need_grad_flags) {
  MS_EXCEPTION_IF_NULL(resource);
  MS_EXCEPTION_IF_NULL(resource->func_graph());
  OptPassGroupMap map;

  opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
    irpass.switch_simplify_,
    irpass.reduce_eliminate_,
    irpass.tile_eliminate_,
    irpass.arithmetic_simplify_,
  });

  opt::OptPassConfig inline_opt = opt::OptPassConfig({
    irpass.inline_,
  });

  // Adapter matching the OptPassConfig pass-function signature; the optimizer
  // argument is unused by ReAutoMonad.
  auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
    return ReAutoMonad(root);
  };

  // Order matters: renormalize before inlining/simplification, auto-monad last.
  map.push_back({"ad_renormalize", opt::OptPassConfig::Renormalize()});
  map.push_back({"ad_inline", inline_opt});
  map.push_back({"ad_special_op_simplify", special_op_simplify});
  map.push_back({"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)});
  if (!need_grad_flags.empty()) {
    // If func graph has not need_grad_flag_of_inputs attr, this graph has no need do this pass.
    opt::OptPassConfig pynative_no_grad_eliminate = opt::OptPassConfig({
      irpass.pynative_no_grad_eliminate_,
    });

    map.push_back({"pynative_no_grad_eliminate", pynative_no_grad_eliminate});
  }

  auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", resource, map);
  FuncGraphPtr func_graph = resource->func_graph();
  ProfileExecute(MsProfile::GetProfile()->Step("prim_bprop_opt_step_2"), [&prim_bprop_opt_step_2, &func_graph]() {
    func_graph = prim_bprop_opt_step_2->step(func_graph, true);
  });
  return func_graph;
}
262
// Optimize a jit bprop graph: inline, eliminate tuple/list conversions and
// getitem/setitem chains, simplify depends/reshapes/switches, then fold
// zeros_like. When need_renormalize is set, also re-infer abstracts
// (Renormalize) and eliminate the helper "real op" wrappers afterwards.
FuncGraphPtr JitBpropGraphPass(const ResourcePtr &resource, bool need_renormalize) {
  opt::irpass::OptimizeIRPassLib irpass;
  opt::OptPassConfig grad_graph_opt = opt::OptPassConfig({
    irpass.inline_,
    irpass.list_to_tuple_eliminator_,
    irpass.tuple_to_list_eliminator_,
    irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_eliminator_,
    irpass.tuple_list_set_item_eliminator_,
    irpass.depend_value_elim_,
    irpass.reshape_eliminate_,
    irpass.switch_simplify_,
    irpass.addn_zero_filter_,
    irpass.ad_related_special_op_eliminate_,
  });
  opt::OptPassConfig fill_zeros_like = opt::OptPassConfig{irpass.zero_like_fill_zero_};
  OptPassGroupMap map({
    {"grad_graph_opt", grad_graph_opt},
    {"zeros_like", fill_zeros_like},
  });
  if (need_renormalize) {
    // real_op_eliminate must run after renormalize so it sees updated abstracts.
    (void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
    opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
    (void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
  }
  MS_EXCEPTION_IF_NULL(resource);
  auto func_graph = resource->func_graph();
  auto graph_opt = opt::Optimizer::MakeOptimizer("jit_bprop_graph_opt", resource, map);
  return graph_opt->step(func_graph, false);
}
293
FinalBpropGraphPass(const ResourcePtr & resource,bool has_control_flow)294 FuncGraphPtr FinalBpropGraphPass(const ResourcePtr &resource, bool has_control_flow) {
295 MS_EXCEPTION_IF_NULL(resource);
296 auto func_graph = resource->func_graph();
297
298 opt::irpass::OptimizeIRPassLib irpass;
299 OptPassGroupMap map;
300 opt::OptPassConfig inline_opt = opt::OptPassConfig({
301 irpass.inline_,
302 });
303 map.emplace_back("ad_inline", inline_opt);
304
305 opt::OptPassConfig grad_graph_opt = opt::OptPassConfig({
306 irpass.tuple_list_get_item_eliminator_,
307 irpass.zero_like_fill_zero_,
308 });
309 (void)map.emplace_back("grad_graph_opt", grad_graph_opt);
310
311 if (has_control_flow) {
312 opt::OptPassConfig env_eliminate = opt::OptPassConfig({
313 irpass.environ_get_eliminate_,
314 irpass.environ_get_add_eliminate_,
315 irpass.environ_get_set_eliminate_,
316 irpass.environ_get_depend_swap_,
317 irpass.environ_add_const_eliminate_,
318 });
319 (void)map.emplace_back(std::make_pair("env_eliminate", env_eliminate));
320 }
321 auto graph_opt = opt::Optimizer::MakeOptimizer("final_bprop_graph_opt", resource, map);
322 return graph_opt->step(func_graph, false);
323 }
324
325 namespace {
ReAutoMonadWrapper(const FuncGraphPtr & root,const opt::OptimizerPtr &)326 bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }
327
parallel_mode()328 bool parallel_mode() {
329 #if defined(__linux__) && defined(WITH_BACKEND)
330 if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
331 return false;
332 }
333 #endif
334 std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
335 return (parallel_mode == parallel::kAutoParallel) || (parallel_mode == parallel::kSemiAutoParallel);
336 }
337
AddParallelRenormalize(OptPassGroupMap * map_a)338 void AddParallelRenormalize(OptPassGroupMap *map_a) {
339 if (parallel_mode()) {
340 auto parallel_end_opt =
341 find_if(map_a->begin(), map_a->end(), [](auto opt_pair) { return opt_pair.first == "meta_fg_expand"; });
342 if (parallel_end_opt != map_a->end()) {
343 opt::irpass::OptimizeIRPassLib irpass;
344 opt::OptPassConfig cast_eliminate_pass = opt::OptPassConfig({irpass.cast_eliminate_});
345 auto iter = map_a->insert(parallel_end_opt, {"cast_eliminate", cast_eliminate_pass});
346 (void)map_a->insert(iter, {"parallel_renormalize", opt::OptPassConfig::Renormalize()});
347 }
348 }
349 }
350
// The "a_1" pass list: defer-inline markers, safe inlining, tuple/list/dict
// item elimination, environ get/set folding, shape-op elimination and
// arithmetic simplification. The safe-inlining sub-list appears twice so that
// call sites exposed by the first round of simplification are inlined and
// cleaned up again in the same fixed-point iteration.
opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
  return opt::OptPassConfig({
    irpass.partial_defer_inline_,
    irpass.switch_defer_inline_,
    irpass.switch_layer_defer_inline_,
    irpass.switch_simplify_,
    irpass.exchange_switch_depend_value_,
    irpass.float_depend_g_call_,

    // Safe inlining
    irpass.inline_,
    irpass.updatestate_useless_node_eliminater_,
    irpass.updatestate_pure_node_eliminater_,
    irpass.load_eliminater_,
    irpass.stopgrad_eliminater_,
    irpass.partial_eliminate_,
    irpass.replace_applicator_,
    irpass.convert_tensor_eliminate_,

    // Miscellaneous
    irpass.list_to_tuple_eliminator_,
    irpass.tuple_to_list_eliminator_,
    irpass.tuple_list_get_item_eliminator_,
    irpass.make_slice_get_slice_eliminator_,
    irpass.tuple_list_get_item_const_eliminator_,
    irpass.tuple_list_set_item_eliminator_,
    irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_depend_reorder_,
    irpass.tuple_list_convert_item_index_to_positive_,
    irpass.dict_get_item_eliminator_,
    irpass.dict_get_item_const_eliminator_,
    irpass.dict_set_item_eliminator_,

    // Environ get/set folding
    irpass.environ_get_eliminate_,
    irpass.environ_get_add_eliminate_,
    irpass.environ_get_set_eliminate_,
    irpass.environ_get_depend_swap_,
    irpass.environ_add_const_eliminate_,

    // Shape/layout op elimination
    irpass.cast_eliminate_,
    irpass.reshape_eliminate_,
    irpass.reduce_eliminate_,
    irpass.tile_eliminate_,
    irpass.transpose_eliminate_,
    irpass.minmaximum_grad_,

    // Arithmetic simplifications
    irpass.arithmetic_simplify_,
    irpass.addn_zero_filter_,
    irpass.adjust_all_reduce_mul_add_,
    irpass.accumulaten_eliminater_,

    // Safe inlining
    irpass.inline_,
    irpass.updatestate_useless_node_eliminater_,
    irpass.updatestate_pure_node_eliminater_,
    irpass.load_eliminater_,
    irpass.stopgrad_eliminater_,
    irpass.print_const_string_wrapper_,
  });
}
412
FlashSPFrontPass(const FuncGraphPtr & func_graph,const opt::OptimizerPtr & optimizer)413 bool FlashSPFrontPass(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer) {
414 if (func_graph->has_flag(parallel::FLASH_SP_RUN_ONCE_ONLY)) {
415 return false;
416 }
417 auto result = parallel::SetFlashSP(func_graph);
418 func_graph->set_flag(parallel::FLASH_SP_RUN_ONCE_ONLY, true);
419 return result;
420 }
421
// Build the main "A" optimization pipeline: simplification and inlining (a_1,
// a_2), parallel-related passes, grad expansion (meta_fg_expand), renormalize
// and final cleanups (a_3). The entry ORDER is load-bearing: GetA1A2() takes a
// positional prefix of this list and GetOptPynativeGradEpiloguePhases() takes
// its last entry, so keep those in sync with any change here.
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig a_1 = GetOptPassA1(irpass);
  // NOTE(review): the trailing (false, true) arguments configure the pass
  // group's execution flags -- confirm exact semantics against OptPassConfig.
  opt::OptPassConfig a_2 = opt::OptPassConfig(
    {
      irpass.switch_simplify_,
      irpass.specialize_transform_,
      irpass.merge_addn_,
      irpass.compare_switch_simplify_,
      irpass.addn_check_dump_,
      irpass.float_tuple_getitem_switch_,
      irpass.float_environ_get_switch_,
      irpass.inline_,
      irpass.updatestate_useless_node_eliminater_,
      irpass.arithmetic_simplify_,
      irpass.tuple_list_set_item_eliminator_,
      irpass.tuple_list_get_item_eliminator_,
      irpass.incorporate_call_,
      irpass.incorporate_call_switch_,
      irpass.environ_get_eliminate_,
      irpass.depend_value_elim_,
      irpass.all_reduce_const_elim_,
    },
    false, true);

  opt::OptPassConfig before_grad = opt::OptPassConfig({irpass.j_node_and_user_rematch_});

  opt::OptPassConfig a_after_grad = opt::OptPassConfig({irpass.inline_without_move_, irpass.stack_unstack_eliminate_});

  opt::OptPassConfig a_3 = opt::OptPassConfig(
    {
      irpass.same_eliminate_,
      irpass.check_bprop_eliminate_,
      irpass.switch_layer_defer_inline_,
      irpass.replace_applicator_,
      irpass.row_tensor_add_zeros_like_,
      irpass.mini_step_allgather_replace_,
      irpass.micro_step_allgather_replace_,
      irpass.split_environ_get_set_with_tuple_value_,
    },
    false, true);
  opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});
  opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
  opt::OptPassConfig after_resolve_pass = opt::OptPassConfig({irpass.replace_old_param_});
  // Disable after_resolve_pass if Pre-Lift is enabled.
  static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
  if (enable_pre_lift) {
    after_resolve_pass.set_disabled(true);
  }
  opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
  opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
  opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
  opt::OptPassConfig recompute_prepare = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
  opt::OptPassConfig get_grad = opt::OptPassConfig({irpass.get_grad_eliminate_});
  opt::OptPassConfig cell_reuse_handle_not_recompute_node_pass =
    opt::OptPassConfig({irpass.remove_not_recompute_node_}, false, true);

  opt::OptPassConfig c_1 = opt::OptPassConfig({
    irpass.switch_call_monad_eliminater_,
    irpass.partial_eliminate_,
  });
  // Disable c_1 if Pre-Lift is not enabled.
  if (!enable_pre_lift) {
    c_1.set_disabled(true);
  }
  // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
  OptPassGroupMap map_a({{"expand_dump_flag", opt::OptPassConfig(opt::irpass::ExpandDumpFlag())},
                         {"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
                         {"a_1", a_1},
                         {"recompute_prepare", recompute_prepare},
                         {"updatestate_depend_eliminate", updatestate_depend_eliminate},
                         {"updatestate_assign_eliminate", updatestate_assign_eliminate},
                         {"updatestate_loads_eliminate", updatestate_loads_eliminate},
                         {"c_1", c_1},
                         {"parameter_eliminate", opt::OptPassConfig(opt::irpass::ParameterEliminator())},
                         {"a_2", a_2},
                         {"accelerated_algorithm", accelerated_algorithm},
                         {"shard", opt::OptPassConfig(parallel::Shard)},
                         {"meta_shard_fg_expand", opt::OptPassConfig(opt::irpass::ExpandMetaShardFg())},
                         {"shard_inline", opt::OptPassConfig({irpass.inline_})},
                         {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
                         {"parallel", opt::OptPassConfig(parallel::StepParallel)},
                         {"flash_sp", opt::OptPassConfig(FlashSPFrontPass)},
                         {"merge_comm", opt::OptPassConfig(parallel::MergeComm)},
                         {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
                         {"matmul_add_comm_reduction", opt::OptPassConfig(parallel::MatmulAddCommReduction)},
                         {"virtual_shard_identity", opt::OptPassConfig({irpass.virtual_shard_identity_})},
                         {"virtual_dataset", virtual_dataset},
                         {"get_grad_eliminate_", get_grad},
                         {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
                         {"merge_forward", opt::OptPassConfig(ad::MergeForward)},
                         {"cell_reuse_recompute_pass", opt::OptPassConfig(opt::irpass::AddRecomputeNodes)},
                         {"cell_reuse_handle_not_recompute_node_pass", cell_reuse_handle_not_recompute_node_pass},
                         {"before_grad", before_grad},
                         {"meta_fg_expand", opt::OptPassConfig(opt::irpass::ExpandMetaFg())},
                         {"receive_attached", opt::OptPassConfig(parallel::IsolatedNodeAttach)},
                         {"after_resolve", after_resolve_pass},
                         {"a_after_grad", a_after_grad},
                         {"renormalize", opt::OptPassConfig::Renormalize()},
                         {"real_op_eliminate", opt::OptPassConfig({irpass.real_op_eliminate_})},
                         {"add_forward_monad_depend", opt::OptPassConfig(opt::irpass::AddForwardMonadDepend)},
                         {"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},
                         {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
                         {"cse", opt::OptPassConfig(opt::CSEPass(false))},
                         {"a_3", a_3}});
  // May splice "parallel_renormalize"/"cast_eliminate" before "meta_fg_expand".
  AddParallelRenormalize(&map_a);
  return map_a;
}
529
GetA1A2(const opt::irpass::OptimizeIRPassLib & irpass)530 OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
531 auto opt_a = GetOptPassesA(irpass);
532 constexpr auto a1_a2_len = 10;
533 OptPassGroupMap a1_a2(opt_a.begin(), opt_a.begin() + a1_a2_len);
534 return a1_a2;
535 }
536
GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib & irpass)537 OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
538 opt::OptPassConfig c_1 = opt::OptPassConfig({
539 // Safe inlining,
540 irpass.inline_,
541 irpass.updatestate_useless_node_eliminater_,
542 irpass.updatestate_pure_node_eliminater_,
543 irpass.load_eliminater_,
544 irpass.switch_call_monad_eliminater_,
545 irpass.stopgrad_eliminater_,
546 irpass.partial_eliminate_,
547 irpass.slice_to_tuple_,
548 });
549 opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
550 opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
551 opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
552
553 OptPassGroupMap map_a({{"c_1", c_1},
554 {"parameter_eliminate", opt::OptPassConfig(opt::irpass::ParameterEliminator())},
555 {"updatestate_depend_eliminate", updatestate_depend_eliminate},
556 {"updatestate_assign_eliminate", updatestate_assign_eliminate},
557 {"updatestate_loads_eliminate", updatestate_loads_eliminate},
558 {"cse", opt::OptPassConfig(opt::CSEPass(false))},
559 {"renormalize", opt::OptPassConfig::Renormalize()}});
560
561 return map_a;
562 }
563
GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib & irpass)564 OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
565 opt::OptPassConfig d_1 = opt::OptPassConfig({
566 irpass.call_graph_tuple_transform_,
567 irpass.list_to_tuple_eliminator_,
568 irpass.tuple_to_list_eliminator_,
569 irpass.tuple_list_get_item_eliminator_,
570 irpass.tuple_list_get_item_const_eliminator_,
571 irpass.tuple_list_set_item_eliminator_,
572 irpass.tuple_list_get_set_item_eliminator_,
573 irpass.tuple_list_get_item_depend_reorder_,
574 irpass.tuple_list_convert_item_index_to_positive_,
575 });
576
577 OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
578
579 return map_a;
580 }
581
// The "B" pipeline: zeros_like folding, tuple/list cleanup, inlining, environ
// get/set folding and value-based elimination (b_1); row-tensor elimination
// (b_2); update-state folding; renormalize; and a final CSE round.
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
  // NOTE(review): the trailing (false, true) arguments configure the pass
  // group's execution flags -- confirm exact semantics against OptPassConfig.
  opt::OptPassConfig b_1 = opt::OptPassConfig({irpass.zero_like_fill_zero_,
                                               irpass.list_to_tuple_eliminator_,
                                               irpass.tuple_to_list_eliminator_,
                                               irpass.tuple_list_get_item_eliminator_,
                                               irpass.tuple_list_get_item_const_eliminator_,
                                               irpass.tuple_list_set_item_eliminator_,
                                               irpass.tuple_list_get_set_item_eliminator_,
                                               irpass.tuple_list_get_item_depend_reorder_,
                                               irpass.tuple_list_convert_item_index_to_positive_,
                                               irpass.make_slice_get_slice_eliminator_,
                                               irpass.float_tuple_getitem_switch_,
                                               irpass.reset_defer_inline_,
                                               irpass.inline_,
                                               irpass.updatestate_useless_node_eliminater_,
                                               irpass.updatestate_pure_node_eliminater_,
                                               irpass.load_eliminater_,
                                               irpass.stopgrad_eliminater_,
                                               irpass.special_op_eliminate_,
                                               irpass.environ_get_eliminate_,
                                               irpass.environ_get_add_eliminate_,
                                               irpass.environ_get_set_eliminate_,
                                               irpass.environ_get_depend_swap_,
                                               irpass.environ_add_const_eliminate_,
                                               irpass.value_based_eliminate_,
                                               irpass.parallel_virtual_node_,
                                               irpass.const_output_eliminate_},
                                              false, true);
  opt::OptPassConfig b_2 = opt::OptPassConfig({
    irpass.row_tensor_eliminate_,
  });
  opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
  opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
  opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
  OptPassGroupMap map({
    {"b_1", b_1},
    {"b_2", b_2},
    {"updatestate_depend_eliminate", updatestate_depend_eliminate},
    {"updatestate_assign_eliminate", updatestate_assign_eliminate},
    {"updatestate_loads_eliminate", updatestate_loads_eliminate},
    {"renormalize", opt::OptPassConfig::Renormalize()},
    {"cse", opt::OptPassConfig(opt::CSEPass(false))},
  });
  return map;
}
627
GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib & irpass)628 OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &irpass) {
629 opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
630 irpass.pynative_eliminate_,
631 });
632
633 OptPassGroupMap map({
634 {"pynative_eliminate", pynative_eliminate},
635 });
636 return map;
637 }
638
GetOptPassesC(const opt::irpass::OptimizeIRPassLib &)639 OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &) {
640 return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
641 }
642
GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib & irpass)643 OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
644 auto opt_a = GetOptPassesA(irpass);
645 auto a3 = opt_a[opt_a.size() - 1];
646 OptPassGroupMap map({
647 {"renormalize", opt::OptPassConfig::Renormalize()},
648 {"cse", opt::OptPassConfig(opt::CSEPass(false))},
649 {a3},
650 });
651 return map;
652 }
653
GetGradPartialTransformPhases()654 OptPassGroupMap GetGradPartialTransformPhases() {
655 opt::irpass::GradPartialPassLib irpass;
656 auto grad_partial_transform = opt::OptPassConfig({irpass.grad_partial_transform_});
657 opt::OptPassGroupMap grad_partial_transform_map({{"grad_partial_transform", grad_partial_transform}});
658 return grad_partial_transform_map;
659 }
660
GetPreparePhases(const opt::irpass::OptimizeIRPassLib & irpass)661 OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
662 opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_});
663 OptPassGroupMap map({{"prepare_group", prepare_group}});
664 return map;
665 }
666
GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &)667 OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
668 OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
669 return map;
670 }
671
GetSymbolEngineOptPass(const opt::irpass::OptimizeIRPassLib & irpass)672 OptPassGroupMap GetSymbolEngineOptPass(const opt::irpass::OptimizeIRPassLib &irpass) {
673 if (common::GetEnv("MS_SYMBOL_ENGINE_OPTIMIZE") == "off") {
674 MS_LOG(INFO) << "SymbolEngineOptimizer is disabled.";
675 return OptPassGroupMap();
676 }
677 OptPassGroupMap map({{"build", opt::OptPassConfig(opt::irpass::SymbolEngineBuilder())},
678 {"elim_shapecalc", opt::OptPassConfig({irpass.elim_shapecalc_of_broadcastargs_})},
679 {"elim_not_effective", opt::OptPassConfig({irpass.elim_not_effective_node_})},
680 {"opt_reshape", opt::OptPassConfig({irpass.opt_reshape_})},
681 {"fold_const_symbol", opt::OptPassConfig({irpass.fold_const_symbol_})},
682 {"shape_op_cse", opt::OptPassConfig(opt::irpass::ShapeOpCse())},
683 {"renormalize", opt::OptPassConfig::Renormalize()}});
684 return map;
685 }
686
// Cache of named optimizers, filled lazily by InitOpt() and emptied by
// ReclaimOptimizer().
static mindspore::HashMap<std::string, std::shared_ptr<Optimizer>> g_pass_opts = {};
688
// Lazily populate the global optimizer table used by OptPassGroup(). Builds
// the table only when it is empty; ReclaimOptimizer() clears it again.
// NOTE(review): the trailing bool arguments of MakeOptimizer appear to control
// run-once / traversal behavior -- confirm against Optimizer::MakeOptimizer.
void InitOpt(const ResourcePtr &resource) {
  if (g_pass_opts.size() == 0) {
    opt::irpass::OptimizeIRPassLib irpass;
    g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", resource, GetA1A2(irpass));
    g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", resource, GetOptPassesA(irpass));
    g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", resource, GetOptPassesB(irpass), false, true);
    g_pass_opts["opt_after_cconv"] =
      Optimizer::MakeOptimizer("opt_after_cconv", resource, GetOptPassesAfterCconv(irpass), false, true);
    g_pass_opts["opt_trans_graph"] =
      Optimizer::MakeOptimizer("opt_trans_graph", resource, GetOptPassesTransformGraph(irpass), true, true);
    g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", resource, GetOptPassesC(irpass));
    g_pass_opts["opt_grad_epilogue"] =
      Optimizer::MakeOptimizer("opt_grad_epilogue", resource, GetOptPynativeGradEpiloguePhases(irpass), true, false);
    g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", resource, GetPreparePhases(irpass));
    g_pass_opts["opt_after_recompute"] =
      Optimizer::MakeOptimizer("opt_after_recompute", resource, GetAfterRecomputePass(irpass));
    g_pass_opts["symbol_engine_opt"] =
      Optimizer::MakeOptimizer("symbol_engine_opt", resource, GetSymbolEngineOptPass(irpass), true, true);
  }
}
709 } // namespace
710
ReclaimOptimizer()711 void ReclaimOptimizer() {
712 for (auto &opt : g_pass_opts) {
713 opt.second = nullptr;
714 }
715 g_pass_opts.clear();
716 }
717
OptPassGroup(const ResourcePtr & resource,const std::string & name)718 bool OptPassGroup(const ResourcePtr &resource, const std::string &name) {
719 MS_EXCEPTION_IF_NULL(resource);
720 if (resource->func_graph() == nullptr) {
721 MS_LOG(ERROR) << "Opt passes int64_t error";
722 return false;
723 }
724
725 FuncGraphPtr func_graph = resource->func_graph();
726 MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
727 << func_graph->get_return()->DebugString(true);
728 InitOpt(resource);
729 if (g_pass_opts.find(name) != g_pass_opts.end()) {
730 resource->set_func_graph(g_pass_opts[name]->step(func_graph));
731 }
732 // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to
733 // resource->args_abs_ yet. So if any later pass or action want to use that variable, it should be set here.
734 return true;
735 }
736
OptPassA1A2(const ResourcePtr & resource)737 bool OptPassA1A2(const ResourcePtr &resource) { return OptPassGroup(resource, "a1a2"); }
OptPassAGroup(const ResourcePtr & resource)738 bool OptPassAGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_a"); }
OptPassBGroup(const ResourcePtr & resource)739 bool OptPassBGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_b"); }
OptPassAfterCconvGroup(const ResourcePtr & resource)740 bool OptPassAfterCconvGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_after_cconv"); }
OptPassTransformGraphGroup(const ResourcePtr & resource)741 bool OptPassTransformGraphGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_trans_graph"); }
ControlGroup(const ResourcePtr & resource)742 bool ControlGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_control"); }
PrepareGroup(const ResourcePtr & resource)743 bool PrepareGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_prepare"); }
OptAfterRecomputeGroup(const ResourcePtr & resource)744 bool OptAfterRecomputeGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_after_recompute"); }
745
OptPassRNGroup(const ResourcePtr & resource)746 bool OptPassRNGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "renormal"); }
SymEngOptGroup(const ResourcePtr & resource)747 bool SymEngOptGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "symbol_engine_opt"); }
748
OptPassGradEpilogueGroup(const ResourcePtr & resource)749 bool OptPassGradEpilogueGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_grad_epilogue"); }
750
AddRecomputationPass(const ResourcePtr & resource)751 bool AddRecomputationPass(const ResourcePtr &resource) {
752 auto context = MsContext::GetInstance();
753 MS_EXCEPTION_IF_NULL(context);
754 if (context->CellReuseLevel() != CellReuseLevel::kNoCellReuse) {
755 return true;
756 }
757 MS_EXCEPTION_IF_NULL(resource);
758 opt::InsertRecomputedNodes(resource->func_graph());
759 return true;
760 }
761
SliceRecomputeActivationPass(const ResourcePtr & resource)762 bool SliceRecomputeActivationPass(const ResourcePtr &resource) {
763 MS_EXCEPTION_IF_NULL(resource);
764 opt::SliceRecomputedActivationNodes(resource->func_graph());
765 return true;
766 }
767
GroupedPairwiseExchangeAllToAllPass(const ResourcePtr & resource)768 bool GroupedPairwiseExchangeAllToAllPass(const ResourcePtr &resource) {
769 MS_EXCEPTION_IF_NULL(resource);
770 opt::SetGroupedPairwiseExchangeAllToAll(resource);
771 return true;
772 }
773
SliceReuseRecomputedActivationPass(const ResourcePtr & resource)774 bool SliceReuseRecomputedActivationPass(const ResourcePtr &resource) {
775 MS_EXCEPTION_IF_NULL(resource);
776 parallel::SliceReuseRecomputedActivationNodes(resource->func_graph());
777 return true;
778 }
779
LabelMicroInterleavedIndexPass(const ResourcePtr & resource)780 bool LabelMicroInterleavedIndexPass(const ResourcePtr &resource) {
781 MS_EXCEPTION_IF_NULL(resource);
782 parallel::LabelMicroInterleavedIndex(resource->func_graph());
783 return true;
784 }
785
OverlapRecomputeAllGatherAndFlashAttentionGradPass(const ResourcePtr & resource)786 bool OverlapRecomputeAllGatherAndFlashAttentionGradPass(const ResourcePtr &resource) {
787 MS_EXCEPTION_IF_NULL(resource);
788 parallel::OverlapRecomputeAllGatherAndFlashAttentionGrad(resource->func_graph());
789 return true;
790 }
791
OptimizeParallelAllGatherCommPass(const ResourcePtr & resource)792 bool OptimizeParallelAllGatherCommPass(const ResourcePtr &resource) {
793 MS_EXCEPTION_IF_NULL(resource);
794 parallel::OptimizeParallelAllGatherComm(resource->func_graph());
795 return true;
796 }
797
LabelFineGrainedInterleavedIndexPass(const ResourcePtr & resource)798 bool LabelFineGrainedInterleavedIndexPass(const ResourcePtr &resource) {
799 MS_EXCEPTION_IF_NULL(resource);
800 parallel::LabelFineGrainedInterleavedIndex(resource->func_graph());
801 return true;
802 }
803
AssignAddOpt(const ResourcePtr & resource)804 bool AssignAddOpt(const ResourcePtr &resource) {
805 MS_EXCEPTION_IF_NULL(resource);
806 FuncGraphPtr func_graph = resource->func_graph();
807 MS_EXCEPTION_IF_NULL(func_graph);
808 parallel::AssignAddOpt(func_graph);
809 auto ms_context = MsContext::GetInstance();
810 auto enable_concat_eliminate = ms_context->get_param<bool>(MS_CTX_ENABLE_CONCAT_ELIMINATE_OPT);
811 if (!enable_concat_eliminate) {
812 return true;
813 }
814 OptPassGroupMap map({{"renormalize", opt::OptPassConfig({opt::OptPassConfig::Renormalize()})}});
815 auto renormalize = opt::Optimizer::MakeOptimizer("renormalize", resource, map);
816 (void)renormalize->step(func_graph, false);
817 return true;
818 }
819
PartialUnusedArgsEliminatePass(const ResourcePtr & resource)820 bool PartialUnusedArgsEliminatePass(const ResourcePtr &resource) {
821 MS_EXCEPTION_IF_NULL(resource);
822 FuncGraphPtr func_graph = resource->func_graph();
823 auto opt = opt::irpass::PartialUnusedArgsEliminate();
824 auto changed = opt(func_graph);
825 if (changed) {
826 OptPassGroupMap map({{"renormalize", opt::OptPassConfig({opt::OptPassConfig::Renormalize()})}});
827 auto renormalize = opt::Optimizer::MakeOptimizer("renormalize", resource, map);
828 (void)renormalize->step(func_graph, false);
829 }
830 return true;
831 }
832
MergeCastOpt(const ResourcePtr & resource)833 bool MergeCastOpt(const ResourcePtr &resource) {
834 MS_EXCEPTION_IF_NULL(resource);
835 parallel::MergeCastOpt(resource->func_graph());
836 return true;
837 }
838
ForceFp32Comm(const ResourcePtr & resource)839 bool ForceFp32Comm(const ResourcePtr &resource) {
840 MS_EXCEPTION_IF_NULL(resource);
841 parallel::Float32Redistribution(resource->func_graph());
842 return true;
843 }
844
SwapDpAllReduceReduceScatterPass(const ResourcePtr & resource)845 bool SwapDpAllReduceReduceScatterPass(const ResourcePtr &resource) {
846 MS_EXCEPTION_IF_NULL(resource);
847 parallel::SwapDpAllreduceReduceScatter(resource->func_graph());
848 return true;
849 }
850
RemoveCastBeforeAssignAdd(const ResourcePtr & resource)851 bool RemoveCastBeforeAssignAdd(const ResourcePtr &resource) {
852 MS_EXCEPTION_IF_NULL(resource);
853 parallel::RemoveCastBeforeAssignAdd(resource->func_graph());
854 return true;
855 }
856
BiasAddCommSwap(const ResourcePtr & resource)857 bool BiasAddCommSwap(const ResourcePtr &resource) {
858 MS_EXCEPTION_IF_NULL(resource);
859 parallel::BiasAddCommSwap(resource->func_graph());
860 return true;
861 }
862
ReorderSendRecvBetweenFpBpPass(const ResourcePtr & resource)863 bool ReorderSendRecvBetweenFpBpPass(const ResourcePtr &resource) {
864 MS_EXCEPTION_IF_NULL(resource);
865 parallel::ReorderSendRecvBetweenFpBp(resource->func_graph());
866 return true;
867 }
868
CompCommSchedulingPass(const ResourcePtr & resource)869 bool CompCommSchedulingPass(const ResourcePtr &resource) {
870 MS_EXCEPTION_IF_NULL(resource);
871 opt::CompCommScheduling(resource->func_graph());
872 return true;
873 }
874
MicroInterLeavedOrderControlPass(const ResourcePtr & resource)875 bool MicroInterLeavedOrderControlPass(const ResourcePtr &resource) {
876 MS_EXCEPTION_IF_NULL(resource);
877 parallel::MicroInterleavedOrderControl(resource->func_graph());
878 return true;
879 }
880
OverlapGradCommPass(const ResourcePtr & resource)881 bool OverlapGradCommPass(const ResourcePtr &resource) {
882 MS_EXCEPTION_IF_NULL(resource);
883 parallel::OverlapGradComm(resource->func_graph());
884 return true;
885 }
886
FullMicroInterLeavedOrderControlPass(const ResourcePtr & resource)887 bool FullMicroInterLeavedOrderControlPass(const ResourcePtr &resource) {
888 MS_EXCEPTION_IF_NULL(resource);
889 parallel::FullMicroInterleavedOrderControl(resource->func_graph());
890 return true;
891 }
892
SplitMatmulCommElementwiseOpFpPass(const ResourcePtr & resource)893 bool SplitMatmulCommElementwiseOpFpPass(const ResourcePtr &resource) {
894 MS_EXCEPTION_IF_NULL(resource);
895 parallel::SplitMatmulCommElementwiseFp(resource->func_graph());
896 return true;
897 }
898
SplitLayerNormCommFpPass(const ResourcePtr & resource)899 bool SplitLayerNormCommFpPass(const ResourcePtr &resource) {
900 MS_EXCEPTION_IF_NULL(resource);
901 parallel::SplitLayerNormCommFp(resource->func_graph());
902 return true;
903 }
904
CommOpAddAttrs(const ResourcePtr & resource)905 bool CommOpAddAttrs(const ResourcePtr &resource) {
906 MS_EXCEPTION_IF_NULL(resource);
907 opt::CommOpAttrs(resource->func_graph());
908 return true;
909 }
910
ProcessSendRecvForGE(const ResourcePtr & resource)911 bool ProcessSendRecvForGE(const ResourcePtr &resource) {
912 MS_EXCEPTION_IF_NULL(resource);
913 opt::ProcessSendRecvForGE(resource->func_graph());
914 return true;
915 }
916
AddCommOpReusePass(const ResourcePtr & resource)917 bool AddCommOpReusePass(const ResourcePtr &resource) {
918 MS_EXCEPTION_IF_NULL(resource);
919 opt::AddCommOpReuseTag(resource->func_graph());
920 return true;
921 }
922
OverlapOptShardInPipelinePass(const ResourcePtr & resource)923 bool OverlapOptShardInPipelinePass(const ResourcePtr &resource) {
924 MS_EXCEPTION_IF_NULL(resource);
925 parallel::OverlapOptShardInPipeline(resource->func_graph());
926 return true;
927 }
928
BeginEndOverlapInlinePass(const ResourcePtr & resource)929 bool BeginEndOverlapInlinePass(const ResourcePtr &resource) {
930 auto ms_context = MsContext::GetInstance();
931 auto is_enable = ms_context->get_param<bool>(MS_CTX_ENABLE_BEGIN_END_INLINE_OPT);
932 if (!is_enable) {
933 return true;
934 }
935 MS_EXCEPTION_IF_NULL(resource);
936 FuncGraphPtr func_graph = resource->func_graph();
937 MS_EXCEPTION_IF_NULL(func_graph);
938 parallel::BeginEndOverlapInlineOpt(resource->func_graph());
939 opt::irpass::OptimizeIRPassLib irpass;
940 opt::OptPassConfig get_item_eliminator_pass = opt::OptPassConfig({irpass.tuple_list_get_item_eliminator_});
941 OptPassGroupMap map({{"get_item_eliminator", get_item_eliminator_pass}});
942 auto get_item_eliminator = opt::Optimizer::MakeOptimizer("get_item_eliminator", resource, map);
943 (void)get_item_eliminator->step(func_graph, false);
944 return true;
945 }
946
OverlapGradMatmulAndGradAllreduce(const ResourcePtr & resource)947 bool OverlapGradMatmulAndGradAllreduce(const ResourcePtr &resource) {
948 MS_EXCEPTION_IF_NULL(resource);
949 parallel::OverlapGradMatmulAndGradAllreduce(resource->func_graph());
950 return true;
951 }
952
OverlapOptShardGradInPipelinePass(const ResourcePtr & resource)953 bool OverlapOptShardGradInPipelinePass(const ResourcePtr &resource) {
954 MS_EXCEPTION_IF_NULL(resource);
955 parallel::OverlapOptShardGradInPipeline(resource->func_graph());
956 return true;
957 }
958
HandleGroupInfoPass(const ResourcePtr & resource)959 bool HandleGroupInfoPass(const ResourcePtr &resource) {
960 MS_EXCEPTION_IF_NULL(resource);
961 parallel::HandleGroupInfo();
962 return true;
963 }
964
OverlapRecomputeAndGradModelParallel(const ResourcePtr & resource)965 bool OverlapRecomputeAndGradModelParallel(const ResourcePtr &resource) {
966 MS_EXCEPTION_IF_NULL(resource);
967 parallel::OverlapRecomputeAndGradModelParallel(resource->func_graph());
968 return true;
969 }
970
AddCacheEmbeddingPass(const ResourcePtr & resource)971 bool AddCacheEmbeddingPass(const ResourcePtr &resource) {
972 MS_EXCEPTION_IF_NULL(resource);
973 #if defined(__linux__) && defined(WITH_BACKEND)
974 if (ps::PSContext::instance()->is_ps_mode()) {
975 return true;
976 }
977 #endif
978 FuncGraphPtr func_graph = resource->func_graph();
979 MS_EXCEPTION_IF_NULL(func_graph);
980
981 parallel::AddCacheEmbedding(func_graph);
982 if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) {
983 auto params = func_graph->parameters();
984 AbstractBasePtrList args_abs_list;
985 (void)std::for_each(params.begin(), params.end(),
986 [&args_abs_list](const AnfNodePtr &node) { args_abs_list.push_back(node->abstract()); });
987 func_graph = pipeline::Renormalize(resource, func_graph, args_abs_list);
988 }
989 return true;
990 }
991
RemoveValueNodeDuplicationsPass(const ResourcePtr & resource)992 bool RemoveValueNodeDuplicationsPass(const ResourcePtr &resource) {
993 MS_EXCEPTION_IF_NULL(resource);
994 if (resource->func_graph() == nullptr) {
995 MS_LOG(INTERNAL_EXCEPTION) << "Remove value node duplications error.";
996 }
997 auto manager = resource->manager();
998 HashCache hash_cache;
999 HashValue hashes;
1000 // Remove duplicated value nodes across all graphs in manager
1001 const auto &node_user_map = manager->node_users();
1002 for (auto &fg : manager->func_graphs()) {
1003 auto value_nodes = fg->value_nodes();
1004 for (const auto &value_pair : value_nodes) {
1005 auto &users = node_user_map.at(value_pair.first);
1006 auto prim = GetValueNode<PrimitivePtr>(value_pair.first);
1007 if (IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
1008 continue;
1009 }
1010 // For data parallel with some parameters redundant, the allreduce will share the same value node
1011 // which will raise an error when do allreduce fusion, so the solution is to make the allreduce's value node
1012 // not be removed, if we found the fusion tag.
1013 if (users.size() == 1) {
1014 auto cnode = users.front().first->cast<CNodePtr>();
1015 if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce) && cnode->size() > 1 && cnode->input(1)->isa<ValueNode>()) {
1016 auto allreduce_prim = GetCNodePrimitive(users.front().first);
1017 auto attrs = allreduce_prim->attrs();
1018 auto fusion_id = attrs.find(mindspore::parallel::FUSION);
1019 if (fusion_id != attrs.end() && GetValue<int64_t>(fusion_id->second) > 0) {
1020 continue;
1021 }
1022 }
1023 }
1024 TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
1025 }
1026 }
1027 return true;
1028 }
1029
CconvPass(const ResourcePtr & resource)1030 bool CconvPass(const ResourcePtr &resource) {
1031 MS_EXCEPTION_IF_NULL(resource);
1032 MS_EXCEPTION_IF_NULL(resource->func_graph());
1033 FuncGraphPtr func_graph = resource->func_graph();
1034 FuncGraphPtr new_fg = LiftingClone(func_graph);
1035 resource->set_func_graph(new_fg);
1036 return true;
1037 }
1038
PipelineSplitPass(const ResourcePtr & resource)1039 bool PipelineSplitPass(const ResourcePtr &resource) { return PipelineSplit(resource); }
1040
ParallelVirtualDatasetPass(const ResourcePtr & resource)1041 bool ParallelVirtualDatasetPass(const ResourcePtr &resource) { return ParallelVirtualDataset(resource); }
1042
PipelineParallelScheduler(const ResourcePtr & resource)1043 bool PipelineParallelScheduler(const ResourcePtr &resource) {
1044 MS_EXCEPTION_IF_NULL(resource);
1045 auto root = resource->func_graph();
1046 auto parallel_context = parallel::ParallelContext::GetInstance();
1047 MS_EXCEPTION_IF_NULL(parallel_context);
1048 auto parallel_mode = parallel_context->parallel_mode();
1049 if (parallel_mode != parallel::kSemiAutoParallel && parallel_mode != parallel::kAutoParallel) {
1050 MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
1051 return true;
1052 }
1053 auto is_pp_interleave = parallel_context->pipeline_interleave();
1054 auto stage_num = parallel_context->pipeline_stage_split_num();
1055 if (is_pp_interleave && stage_num > 1) {
1056 auto manager = resource->manager();
1057 auto stage = parallel::InferStage();
1058 auto pp_scheduler = parallel_context->pipeline_scheduler();
1059 std::shared_ptr<parallel::PipelineScheduler> scheduler = nullptr;
1060 if (pp_scheduler == parallel::kPipeline1F1B) {
1061 scheduler = std::make_shared<parallel::InterleavedScheduler>(manager, root, stage, stage_num);
1062 } else if (pp_scheduler == parallel::kPipelineGpipe) {
1063 scheduler = std::make_shared<parallel::GpipeInterleavedScheduler>(manager, root, stage, stage_num);
1064 } else {
1065 MS_LOG(EXCEPTION) << "Unsupported pipeline parallel scheduler: " << pp_scheduler;
1066 }
1067 scheduler->GetBorderNode();
1068 scheduler->Reorder();
1069 }
1070 opt::ProcessSendRecvForGE(root);
1071 return true;
1072 }
1073
AutoParallelPass(const ResourcePtr & resource)1074 bool AutoParallelPass(const ResourcePtr &resource) {
1075 auto func_graph = resource->func_graph();
1076 auto opt = opt::Optimizer::MakeEmptyOptimizer(resource);
1077 return parallel::StepAutoParallel(func_graph, opt);
1078 }
1079
AutoParallelSymbolPassWithReNormalize(const ResourcePtr & resource)1080 bool AutoParallelSymbolPassWithReNormalize(const ResourcePtr &resource) {
1081 // 1, auto parallel; 2, dynamic shape
1082 auto func_graph = resource->func_graph();
1083 if (!parallel::IsParallelDynamicShape(func_graph)) {
1084 return true;
1085 }
1086 MS_LOG(INFO) << "symbol pass for parallel begin";
1087 // must be bind with renormalize
1088 OptPassGroupMap opt_map({{"renormalize", opt::OptPassConfig::Renormalize()},
1089 {"build", opt::OptPassConfig(opt::irpass::SymbolEngineBuilder())}});
1090 auto opt = opt::Optimizer::MakeOptimizer("parallel-infer-symbol", resource, opt_map, true);
1091 (void)opt->step(func_graph, false);
1092 MS_LOG(INFO) << "symbol pass for parallel end";
1093 return true;
1094 }
1095
ValidatePass(const ResourcePtr & resource)1096 bool ValidatePass(const ResourcePtr &resource) {
1097 MS_EXCEPTION_IF_NULL(resource);
1098 MS_EXCEPTION_IF_NULL(resource->func_graph());
1099 FuncGraphPtr func_graph = resource->func_graph();
1100 Validate(func_graph);
1101 return true;
1102 }
1103
GradPartialTransformPass(const ResourcePtr & resource)1104 bool GradPartialTransformPass(const ResourcePtr &resource) {
1105 MS_EXCEPTION_IF_NULL(resource);
1106 FuncGraphPtr func_graph = resource->func_graph();
1107 MS_EXCEPTION_IF_NULL(func_graph);
1108 auto grad_partial_transform_map = GetGradPartialTransformPhases();
1109 auto grad_partial_transform =
1110 opt::Optimizer::MakeOptimizer("grad_partial_transform", resource, grad_partial_transform_map);
1111 (void)grad_partial_transform->step(func_graph, false);
1112 return true;
1113 }
1114
PynativeOptPass(const ResourcePtr & resource)1115 bool PynativeOptPass(const ResourcePtr &resource) {
1116 MS_EXCEPTION_IF_NULL(resource);
1117 FuncGraphPtr func_graph = resource->func_graph();
1118 MS_EXCEPTION_IF_NULL(func_graph);
1119 opt::irpass::OptimizeIRPassLib irpass;
1120 auto pynative_opt = GetOptPassesPynativeElim(irpass);
1121 auto pynative_opt_opt = opt::Optimizer::MakeOptimizer("pynative_opt", resource, pynative_opt);
1122 (void)pynative_opt_opt->step(func_graph, false);
1123 return true;
1124 }
1125
EliminateSpecialOpOptPass(const ResourcePtr & resource)1126 bool EliminateSpecialOpOptPass(const ResourcePtr &resource) {
1127 MS_EXCEPTION_IF_NULL(resource);
1128 auto func_graph = resource->func_graph();
1129 MS_EXCEPTION_IF_NULL(func_graph);
1130 opt::irpass::OptimizeIRPassLib irpass;
1131 opt::OptPassConfig ad_related_special_op_eliminate = opt::OptPassConfig({
1132 irpass.ad_related_special_op_eliminate_,
1133 });
1134 opt::OptPassConfig mutable_op_eliminate = opt::OptPassConfig({
1135 irpass.mutable_op_eliminate_,
1136 });
1137 opt::OptPassConfig convert_tensor_op_eliminate = opt::OptPassConfig({
1138 irpass.convert_tensor_all_eliminate_,
1139 });
1140 OptPassGroupMap map({
1141 {"ad_related_special_op_eliminate", ad_related_special_op_eliminate},
1142 {"mutable_op_eliminate", mutable_op_eliminate},
1143 {"convert_tensor_op_eliminate", convert_tensor_op_eliminate},
1144 });
1145 auto special_op_eliminate_opt = opt::Optimizer::MakeOptimizer("special_op_eliminate", resource, map);
1146 (void)special_op_eliminate_opt->step(func_graph, false);
1147 return true;
1148 }
1149
AutoMonadElimOptPass(const FuncGraphPtr & func_graph)1150 bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
1151 MS_EXCEPTION_IF_NULL(func_graph);
1152 MS_EXCEPTION_IF_NULL(func_graph->manager());
1153 auto resource = std::make_shared<pipeline::Resource>();
1154 resource->set_func_graph(func_graph);
1155 resource->set_manager(func_graph->manager());
1156
1157 // opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
1158 opt::SubstitutionPtr updatestate_useless_node_eliminater =
1159 opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateUselessNodeEliminater>(),
1160 "updatestate_useless_node_eliminater", prim::kPrimUpdateState);
1161 opt::SubstitutionPtr updatestate_pure_node_eliminater =
1162 opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestatePureNodeEliminater>(),
1163 "updatestate_pure_node_eliminater", prim::kPrimUpdateState);
1164
1165 opt::OptPassConfig updatestate_eliminater = opt::OptPassConfig({
1166 updatestate_useless_node_eliminater,
1167 updatestate_pure_node_eliminater,
1168 });
1169 opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
1170 opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
1171 opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
1172 opt::OptPassGroupMap elim_map({
1173 {"updatestate_eliminater", updatestate_eliminater},
1174 {"updatestate_depend_eliminate", updatestate_depend_eliminate},
1175 {"updatestate_assign_eliminate", updatestate_assign_eliminate},
1176 {"updatestate_loads_eliminate", updatestate_loads_eliminate},
1177 {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
1178 });
1179
1180 auto auto_monad_elim_opt = opt::Optimizer::MakeOptimizer("auto_monad_elim", resource, elim_map);
1181 (void)auto_monad_elim_opt->step(func_graph, false);
1182 return true;
1183 }
1184
EnvironConversionPass(const ResourcePtr & resource)1185 bool EnvironConversionPass(const ResourcePtr &resource) {
1186 MS_EXCEPTION_IF_NULL(resource);
1187 (void)opt::EnvironConversion(resource);
1188 return true;
1189 }
1190
// Build service-side graph for embedding distributed cache based on Parameter Server.
// Only active on server nodes of an initialized cluster with the embedding cache
// enabled; otherwise returns true without touching the graph. After inserting the
// cache ops it renormalizes the graph and publishes it back into the resource.
bool AddEmbeddingCachePass(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  // Skip unless: cache enabled, cluster initialized, and this process is a server.
  if (!ps::PSContext::instance()->cache_enable() || !distributed::cluster::ClusterContext::instance()->initialized() ||
      !ps::PSContext::instance()->is_server()) {
    return true;
  }

  FuncGraphPtr func_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto node = distributed::cluster::ClusterContext::instance()->node();
  MS_EXCEPTION_IF_NULL(node);

  // 1. Build service-size graph.
  auto node_role = distributed::cluster::ClusterContext::instance()->node_role();
  uint32_t worker_num = ps::PSContext::instance()->worker_num();
  std::shared_ptr<parallel::PsEmbeddingCacheInserter> embedding_cache_inserter =
    std::make_shared<parallel::PsEmbeddingCacheInserter>(func_graph, static_cast<int64_t>(node->rank_id()), node_role,
                                                         worker_num);
  if (!embedding_cache_inserter->Run()) {
    MS_LOG(ERROR) << "Insert ps embedding cache failed.";
    return false;
  }

  // 2. Renomalize: Infer shape and Set abstract for all nodes in graph.
  // Collect the current parameter abstracts as renormalization inputs.
  abstract::AbstractBasePtrList args_abs;
  auto parameters = func_graph->parameters();
  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
                       [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
  FuncGraphPtr new_fg = Renormalize(resource, func_graph, args_abs);
  // Publish both the re-inferred graph and its argument abstracts.
  resource->set_func_graph(new_fg);
  resource->set_args_abs(args_abs);
#endif

  return true;
}
1228
// Pass pipeline for graph (VM) mode. Entries run in order; each PassItem pairs a
// phase name with its pass function. Reordering entries changes compile behavior.
std::vector<PassItem> kVmPasses = {
  {"py_interpret_to_execute", PyInterpretToExecutePass},
  {"rewriter_before_opt_a", RewriterBeforeOptAPass},
  {"opt_a", OptPassAGroup},
  {"py_interpret_to_execute_after_opt_a", PyInterpretToExecutePass},
  {"slice_cell_reuse_recomputed_activation", SliceReuseRecomputedActivationPass},
  {"rewriter_after_opt_a", RewriterAfterOptAPass},
  {"convert_after_rewriter", ConvertAfterRewriterPass},
  {"order_py_execute_after_rewriter", OrderPyExecuteAfterRewriterPass},
  {"opt_b", OptPassBGroup},
  {"optimize_parallel_all_gather_comm", OptimizeParallelAllGatherCommPass},
  {"cconv", CconvPass},
  {"opt_after_cconv", OptPassAfterCconvGroup},
  {"remove_dup_value", RemoveValueNodeDuplicationsPass},
  {"tuple_transform", OptPassTransformGraphGroup},
  {"partial_unused_args_eliminate", PartialUnusedArgsEliminatePass},
  {"add_cache_embedding", AddCacheEmbeddingPass},
  {"add_recomputation", AddRecomputationPass},
  {"cse_after_recomputation", OptAfterRecomputeGroup},
  {"environ_conv", EnvironConversionPass},
  {"swap_dp_allreduce_reducescatter", SwapDpAllReduceReduceScatterPass},
  {"bias_add_comm_swap", BiasAddCommSwap},
  {"label_micro_interleaved_index", LabelMicroInterleavedIndexPass},
  {"label_fine_grained_interleaved_index", LabelFineGrainedInterleavedIndexPass},
  {"merge_cast_opt", MergeCastOpt},
  {"slice_recompute_activation", SliceRecomputeActivationPass},
  {"micro_interleaved_order_control", MicroInterLeavedOrderControlPass},
  {"assign_add_opt", AssignAddOpt},
  {"ForceFp32Comm", ForceFp32Comm},
  {"remove_cast_before_assign_add", RemoveCastBeforeAssignAdd},
  {"full_micro_interleaved_order_control", FullMicroInterLeavedOrderControlPass},
  {"comp_comm_scheduling", CompCommSchedulingPass},
  {"reorder_send_recv_between_fp_bp", ReorderSendRecvBetweenFpBpPass},
  {"comm_op_add_attrs", CommOpAddAttrs},
  {"add_comm_op_reuse_tag", AddCommOpReusePass},
  {"overlap_opt_shard_in_pipeline", OverlapOptShardInPipelinePass},
  {"overlap_opt_shard_grad_in_pipeline", OverlapOptShardGradInPipelinePass},
  {"grouped_pairwise_exchange_alltoall", GroupedPairwiseExchangeAllToAllPass},
  {"overlap_recompute_and_grad_model_parallel", OverlapRecomputeAndGradModelParallel},
  {"overlap_grad_matmul_and_grad_allreduce", OverlapGradMatmulAndGradAllreduce},
  {"overlap_recompute_allgather_and_fa_grad", OverlapRecomputeAllGatherAndFlashAttentionGradPass},
  {"begin_end_overlap_inline", BeginEndOverlapInlinePass},
  {"overlap_grad_comm", OverlapGradCommPass},
  {"split_matmul_comm_elemetwise", SplitMatmulCommElementwiseOpFpPass},
  {"split_layernorm_comm", SplitLayerNormCommFpPass},
  // The pass cache hccl group, so the hccl group should be created before the pass
  {"handle_group_info", HandleGroupInfoPass},
  {"symbol_engine_optimizer", SymEngOptGroup}};
1277
// Reduced pass pipeline used in pynative mode; runs in order.
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
                                         {"opt_b", OptPassBGroup},
                                         {"cconv", CconvPass},
                                         {"transform_top", TransformTopGraphPass},
                                         {"transform_graph", OptPassTransformGraphGroup}};
1283
// Minimal pipeline used for the inline path: rewriter plus the a1/a2 groups.
std::vector<PassItem> kInlinePasses = {{"rewriter_before_opt_a", RewriterBeforeOptAPass}, {"a1a2", OptPassA1A2}};
1285 } // namespace pipeline
1286 } // namespace mindspore
1287