• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/ps/pass.h"
18 
19 #include <memory>
20 #include <vector>
21 #include <string>
22 #include <algorithm>
23 
24 #include "mindspore/core/ops/other_ops.h"
25 #include "mindspore/core/ops/framework_ops.h"
26 #include "utils/hash_map.h"
27 #include "ir/func_graph_cloner.h"
28 #include "pipeline/jit/ps/parse/parse_base.h"
29 #include "pipeline/jit/ps/resource.h"
30 #include "pipeline/jit/ps/validator.h"
31 #include "pipeline/jit/ps/remove_value_node_dup.h"
32 #include "frontend/optimizer/opt.h"
33 #include "frontend/optimizer/optimizer.h"
34 #include "frontend/optimizer/cse_pass.h"
35 #include "frontend/optimizer/fallback_rewriter.h"
36 #include "frontend/optimizer/irpass.h"
37 #include "frontend/optimizer/graph_transform.h"
38 #include "frontend/optimizer/auto_monad_eliminate.h"
39 #include "include/common/fallback.h"
40 #include "include/common/utils/parallel_context.h"
41 #include "frontend/parallel/dynamic_shape/dynamic_shape.h"
42 #include "frontend/parallel/step_parallel.h"
43 #include "frontend/parallel/step_auto_parallel.h"
44 #include "frontend/parallel/graph_util/pipeline_split_utils.h"
45 #include "frontend/parallel/pipeline_transformer/pipeline_scheduler.h"
46 #include "frontend/parallel/pipeline_transformer/pipeline_interleave.h"
47 #include "frontend/parallel/pipeline_transformer/gpipe_interleave_scheduler.h"
48 #include "frontend/parallel/pass/merge_comm.h"
49 #include "frontend/parallel/cache_embedding/cache_embedding.h"
50 #include "frontend/parallel/cache_embedding/ps_embedding_cache_inserter.h"
51 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
52 #include "frontend/parallel/shard/shard.h"
53 #include "frontend/parallel/pass/optimize_parallel_allgather_comm.h"
54 #include "frontend/parallel/pass/label_micro_interleaved_index.h"
55 #include "frontend/parallel/pass/label_fine_grained_interleaved_index.h"
56 #include "frontend/parallel/pass/reorder_send_recv_between_fp_bp.h"
57 #include "frontend/parallel/pass/micro_interleaved_order_control.h"
58 #include "frontend/parallel/pass/full_micro_interleaved_order_control.h"
59 #include "frontend/parallel/pass/overlap_recompute_allgather_and_flashattention_grad.h"
60 #include "frontend/parallel/pass/assign_add_opt.h"
61 #include "frontend/parallel/pass/float32_redistribution.h"
62 #include "frontend/parallel/pass/swap_dp_allreduce_reducescatter.h"
63 #include "frontend/parallel/pass/merge_cast_opt.h"
64 #include "frontend/parallel/pass/remove_cast_before_assign_add.h"
65 #include "frontend/parallel/pass/bias_add_comm_swap.h"
66 #include "frontend/parallel/pass/matmul_add_comm_reduction.h"
67 #include "frontend/parallel/pass/comp_comm_scheduling.h"
68 #include "frontend/parallel/pass/overlap_opt_shard_in_pipeline.h"
69 #include "frontend/parallel/pass/slice_activation_in_cell_share_recompute.h"
70 #include "frontend/parallel/pass/handle_group_info.h"
71 #include "frontend/parallel/pass/overlap_recompute_and_grad_model_parallel.h"
72 #include "frontend/parallel/pass/overlap_gradmatmul_and_gradallreduce.h"
73 #include "frontend/parallel/pass/begin_end_overlap_inline.h"
74 #include "frontend/parallel/pass/split_matmul_comm_elementwise_fp.h"
75 #include "frontend/parallel/pass/split_layernorm_comm_fp.h"
76 #include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
77 #include "frontend/parallel/pass/overlap_grad_comm.h"
78 #include "frontend/optimizer/recompute.h"
79 #include "frontend/optimizer/irpass/recompute.h"
80 #include "frontend/optimizer/slice_activation_in_recompute.h"
81 #include "frontend/optimizer/grouped_pairwise_exchange_alltoall.h"
82 #include "frontend/optimizer/comm_op_attrs.h"
83 #include "frontend/optimizer/process_send_recv_for_ge.h"
84 #include "frontend/optimizer/environ_conversion.h"
85 #include "frontend/optimizer/comm_op_reuse_tag.h"
86 #include "frontend/optimizer/py_interpret_to_execute.h"
87 #include "frontend/optimizer/flash_sp.h"
88 #include "utils/log_adapter.h"
89 #include "utils/compile_config.h"
90 #include "pipeline/jit/ps/pipeline_split.h"
91 #include "pipeline/pynative/pynative_execute.h"
92 #include "pipeline/jit/ps/static_analysis/auto_monad.h"
93 #include "frontend/optimizer/irpass/branch_culling.h"
94 #include "frontend/optimizer/irpass/meta_fg_eliminate.h"
95 #include "frontend/optimizer/irpass/gradient_eliminate.h"
96 #include "frontend/optimizer/irpass/shard_eliminate.h"
97 #include "frontend/optimizer/irpass/taylor_eliminate.h"
98 #include "frontend/optimizer/irpass/parameter_eliminate.h"
99 #include "frontend/optimizer/irpass/updatestate_eliminate.h"
100 #include "frontend/optimizer/irpass/expand_dump_flag.h"
101 #include "frontend/optimizer/irpass/symbol_engine_optimizer.h"
102 #include "frontend/optimizer/irpass/add_forward_monad_depend.h"
103 #if defined(__linux__) && defined(WITH_BACKEND)
104 #include "include/backend/distributed/ps/util.h"
105 #include "include/backend/distributed/ps/ps_context.h"
106 #endif
107 
108 namespace mindspore {
109 namespace pipeline {
110 using OptPassGroupMap = opt::OptPassGroupMap;
111 using Optimizer = opt::Optimizer;
112 using CompileGraphs = compile::CompileGraphs;
113 using abstract::AnalysisResult;
114 using mindspore::abstract::AnalysisContextPtr;
115 using mindspore::validator::Validate;
UpdateArgsSpec(const FuncGraphPtr & func_graph,const ResourcePtr & resource)116 void UpdateArgsSpec(const FuncGraphPtr &func_graph, const ResourcePtr &resource) {
117   MS_EXCEPTION_IF_NULL(func_graph);
118   MS_EXCEPTION_IF_NULL(resource);
119   abstract::AbstractBasePtrList args_abs;
120   const auto &parameters = func_graph->parameters();
121   args_abs.reserve(parameters.size());
122   (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
123                        [](const AnfNodePtr &p) { return p->abstract(); });
124   resource->set_args_abs(args_abs);
125 }
126 
PyInterpretToExecutePass(const ResourcePtr & resource)127 bool PyInterpretToExecutePass(const ResourcePtr &resource) {
128   const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() == kLax);
129   if (!allow_fallback_runtime) {
130     return true;
131   }
132   MS_EXCEPTION_IF_NULL(resource);
133   FuncGraphPtr func_graph = resource->func_graph();
134   MS_EXCEPTION_IF_NULL(func_graph);
135   (void)opt::PyInterpretToExecute(resource);
136   UpdateArgsSpec(func_graph, resource);
137   return true;
138 }
139 
RewriterBeforeOptAPass(const ResourcePtr & resource)140 bool RewriterBeforeOptAPass(const ResourcePtr &resource) {
141   MS_EXCEPTION_IF_NULL(resource);
142   FuncGraphPtr func_graph = resource->func_graph();
143   MS_EXCEPTION_IF_NULL(func_graph);
144   (void)opt::RewriterBeforeOptA(func_graph, resource->manager());
145   UpdateArgsSpec(func_graph, resource);
146   return true;
147 }
148 
TransformTopGraphPass(const ResourcePtr & resource)149 bool TransformTopGraphPass(const ResourcePtr &resource) {
150   MS_EXCEPTION_IF_NULL(resource);
151   if (resource->func_graph() == nullptr) {
152     MS_LOG(INTERNAL_EXCEPTION) << "Transform top graph error.";
153   }
154   FuncGraphPtr func_graph = resource->func_graph();
155   if (opt::FuncGraphHasSequenceInput(func_graph)) {
156     opt::GraphSequenceParamTransform graph_trans;
157     func_graph = graph_trans(func_graph, resource->manager());
158     resource->set_func_graph(func_graph);
159     AbstractBasePtrList abs_spec_list;
160     auto &params = func_graph->parameters();
161     (void)std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
162                          [](const AnfNodePtr &node) { return node->abstract(); });
163     resource->set_args_abs(abs_spec_list);
164   }
165   return true;
166 }
167 
RewriterAfterOptAPass(const ResourcePtr & resource)168 bool RewriterAfterOptAPass(const ResourcePtr &resource) {
169   MS_EXCEPTION_IF_NULL(resource);
170   FuncGraphPtr func_graph = resource->func_graph();
171   MS_EXCEPTION_IF_NULL(func_graph);
172   (void)opt::RewriterAfterOptA(func_graph, resource);
173   UpdateArgsSpec(func_graph, resource);
174   return true;
175 }
176 
ConvertAfterRewriterPass(const ResourcePtr & resource)177 bool ConvertAfterRewriterPass(const ResourcePtr &resource) {
178   MS_EXCEPTION_IF_NULL(resource);
179   FuncGraphPtr func_graph = resource->func_graph();
180   MS_EXCEPTION_IF_NULL(func_graph);
181   (void)opt::ConvertAfterRewriter(func_graph, resource);
182   UpdateArgsSpec(func_graph, resource);
183   return true;
184 }
185 
OrderPyExecuteAfterRewriterPass(const ResourcePtr & resource)186 bool OrderPyExecuteAfterRewriterPass(const ResourcePtr &resource) {
187   MS_EXCEPTION_IF_NULL(resource);
188   FuncGraphPtr func_graph = resource->func_graph();
189   MS_EXCEPTION_IF_NULL(func_graph);
190   (void)opt::OrderPyExecuteAfterRewriter(func_graph, resource);
191   UpdateArgsSpec(func_graph, resource);
192   return true;
193 }
194 
PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib & irpass,const ResourcePtr & resource)195 FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource) {
196   MS_EXCEPTION_IF_NULL(resource);
197   MS_EXCEPTION_IF_NULL(resource->func_graph());
198   opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
199     irpass.pynative_eliminate_,
200   });
201 
202   opt::OptPassConfig switch_simplify = opt::OptPassConfig({
203     irpass.switch_simplify_,
204   });
205 
206   opt::OptPassConfig inline_opt = opt::OptPassConfig({
207     irpass.inline_,
208   });
209 
210   OptPassGroupMap map(
211     {{"ad_eliminate", pynative_eliminate}, {"ad_inline", inline_opt}, {"ad_switch_simplify", switch_simplify}});
212 
213   auto prim_bprop_opt_step_1 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_1", resource, map);
214   FuncGraphPtr func_graph = resource->func_graph();
215   ProfileExecute(MsProfile::GetProfile()->Step("prim_bprop_opt_step_1"), [&prim_bprop_opt_step_1, &func_graph]() {
216     func_graph = prim_bprop_opt_step_1->step(func_graph, true);
217   });
218   return func_graph;
219 }
220 
// Second bprop optimize step for PyNative grad graphs: renormalize, inline,
// simplify special ops, and re-run auto-monad; optionally drops ops whose
// inputs require no gradient.
//
// @param irpass          Library holding the shared irpass substitutions.
// @param resource        Pipeline resource owning the func graph to optimize.
// @param need_grad_flags Per-input "needs gradient" flags; when non-empty the
//                        pynative_no_grad_eliminate pass is appended.
// @return The optimized func graph.
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &resource,
                                const std::vector<bool> &need_grad_flags) {
  MS_EXCEPTION_IF_NULL(resource);
  MS_EXCEPTION_IF_NULL(resource->func_graph());
  OptPassGroupMap map;

  opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
    irpass.switch_simplify_,
    irpass.reduce_eliminate_,
    irpass.tile_eliminate_,
    irpass.arithmetic_simplify_,
  });

  opt::OptPassConfig inline_opt = opt::OptPassConfig({
    irpass.inline_,
  });

  // Adapt ReAutoMonad to the (graph, optimizer) pass signature; the optimizer
  // argument is intentionally unused.
  auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
    return ReAutoMonad(root);
  };

  // Pass order is significant: renormalize first, then inline/simplify, then
  // re-establish monad ordering on the transformed graph.
  map.push_back({"ad_renormalize", opt::OptPassConfig::Renormalize()});
  map.push_back({"ad_inline", inline_opt});
  map.push_back({"ad_special_op_simplify", special_op_simplify});
  map.push_back({"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)});
  if (!need_grad_flags.empty()) {
    // If func graph has not need_grad_flag_of_inputs attr, this graph has no need do this pass.
    opt::OptPassConfig pynative_no_grad_eliminate = opt::OptPassConfig({
      irpass.pynative_no_grad_eliminate_,
    });

    map.push_back({"pynative_no_grad_eliminate", pynative_no_grad_eliminate});
  }

  auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", resource, map);
  FuncGraphPtr func_graph = resource->func_graph();
  // Profile the step so it appears in compile-time statistics.
  ProfileExecute(MsProfile::GetProfile()->Step("prim_bprop_opt_step_2"), [&prim_bprop_opt_step_2, &func_graph]() {
    func_graph = prim_bprop_opt_step_2->step(func_graph, true);
  });
  return func_graph;
}
262 
JitBpropGraphPass(const ResourcePtr & resource,bool need_renormalize)263 FuncGraphPtr JitBpropGraphPass(const ResourcePtr &resource, bool need_renormalize) {
264   opt::irpass::OptimizeIRPassLib irpass;
265   opt::OptPassConfig grad_graph_opt = opt::OptPassConfig({
266     irpass.inline_,
267     irpass.list_to_tuple_eliminator_,
268     irpass.tuple_to_list_eliminator_,
269     irpass.tuple_list_get_set_item_eliminator_,
270     irpass.tuple_list_get_item_eliminator_,
271     irpass.tuple_list_set_item_eliminator_,
272     irpass.depend_value_elim_,
273     irpass.reshape_eliminate_,
274     irpass.switch_simplify_,
275     irpass.addn_zero_filter_,
276     irpass.ad_related_special_op_eliminate_,
277   });
278   opt::OptPassConfig fill_zeros_like = opt::OptPassConfig{irpass.zero_like_fill_zero_};
279   OptPassGroupMap map({
280     {"grad_graph_opt", grad_graph_opt},
281     {"zeros_like", fill_zeros_like},
282   });
283   if (need_renormalize) {
284     (void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
285     opt::OptPassConfig real_op_eliminate = opt::OptPassConfig{irpass.real_op_eliminate_};
286     (void)map.emplace_back(std::make_pair("real_op_eliminate", real_op_eliminate));
287   }
288   MS_EXCEPTION_IF_NULL(resource);
289   auto func_graph = resource->func_graph();
290   auto graph_opt = opt::Optimizer::MakeOptimizer("jit_bprop_graph_opt", resource, map);
291   return graph_opt->step(func_graph, false);
292 }
293 
FinalBpropGraphPass(const ResourcePtr & resource,bool has_control_flow)294 FuncGraphPtr FinalBpropGraphPass(const ResourcePtr &resource, bool has_control_flow) {
295   MS_EXCEPTION_IF_NULL(resource);
296   auto func_graph = resource->func_graph();
297 
298   opt::irpass::OptimizeIRPassLib irpass;
299   OptPassGroupMap map;
300   opt::OptPassConfig inline_opt = opt::OptPassConfig({
301     irpass.inline_,
302   });
303   map.emplace_back("ad_inline", inline_opt);
304 
305   opt::OptPassConfig grad_graph_opt = opt::OptPassConfig({
306     irpass.tuple_list_get_item_eliminator_,
307     irpass.zero_like_fill_zero_,
308   });
309   (void)map.emplace_back("grad_graph_opt", grad_graph_opt);
310 
311   if (has_control_flow) {
312     opt::OptPassConfig env_eliminate = opt::OptPassConfig({
313       irpass.environ_get_eliminate_,
314       irpass.environ_get_add_eliminate_,
315       irpass.environ_get_set_eliminate_,
316       irpass.environ_get_depend_swap_,
317       irpass.environ_add_const_eliminate_,
318     });
319     (void)map.emplace_back(std::make_pair("env_eliminate", env_eliminate));
320   }
321   auto graph_opt = opt::Optimizer::MakeOptimizer("final_bprop_graph_opt", resource, map);
322   return graph_opt->step(func_graph, false);
323 }
324 
325 namespace {
ReAutoMonadWrapper(const FuncGraphPtr & root,const opt::OptimizerPtr &)326 bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }
327 
parallel_mode()328 bool parallel_mode() {
329 #if defined(__linux__) && defined(WITH_BACKEND)
330   if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
331     return false;
332   }
333 #endif
334   std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
335   return (parallel_mode == parallel::kAutoParallel) || (parallel_mode == parallel::kSemiAutoParallel);
336 }
337 
AddParallelRenormalize(OptPassGroupMap * map_a)338 void AddParallelRenormalize(OptPassGroupMap *map_a) {
339   if (parallel_mode()) {
340     auto parallel_end_opt =
341       find_if(map_a->begin(), map_a->end(), [](auto opt_pair) { return opt_pair.first == "meta_fg_expand"; });
342     if (parallel_end_opt != map_a->end()) {
343       opt::irpass::OptimizeIRPassLib irpass;
344       opt::OptPassConfig cast_eliminate_pass = opt::OptPassConfig({irpass.cast_eliminate_});
345       auto iter = map_a->insert(parallel_end_opt, {"cast_eliminate", cast_eliminate_pass});
346       (void)map_a->insert(iter, {"parallel_renormalize", opt::OptPassConfig::Renormalize()});
347     }
348   }
349 }
350 
// Build the "a_1" pass config: the first and largest simplification stage of
// the opt-A pipeline. Substitutions are applied in the listed order; the
// inlining group appears twice so simplifications between the two runs can
// expose further inline opportunities.
opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
  return opt::OptPassConfig({
    // Deferred-inline markers and switch simplification.
    irpass.partial_defer_inline_,
    irpass.switch_defer_inline_,
    irpass.switch_layer_defer_inline_,
    irpass.switch_simplify_,
    irpass.exchange_switch_depend_value_,
    irpass.float_depend_g_call_,

    // Safe inlining
    irpass.inline_,
    irpass.updatestate_useless_node_eliminater_,
    irpass.updatestate_pure_node_eliminater_,
    irpass.load_eliminater_,
    irpass.stopgrad_eliminater_,
    irpass.partial_eliminate_,
    irpass.replace_applicator_,
    irpass.convert_tensor_eliminate_,

    // Miscellaneous
    irpass.list_to_tuple_eliminator_,
    irpass.tuple_to_list_eliminator_,
    irpass.tuple_list_get_item_eliminator_,
    irpass.make_slice_get_slice_eliminator_,
    irpass.tuple_list_get_item_const_eliminator_,
    irpass.tuple_list_set_item_eliminator_,
    irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_depend_reorder_,
    irpass.tuple_list_convert_item_index_to_positive_,
    irpass.dict_get_item_eliminator_,
    irpass.dict_get_item_const_eliminator_,
    irpass.dict_set_item_eliminator_,

    // Environ get/set chain elimination.
    irpass.environ_get_eliminate_,
    irpass.environ_get_add_eliminate_,
    irpass.environ_get_set_eliminate_,
    irpass.environ_get_depend_swap_,
    irpass.environ_add_const_eliminate_,

    // Shape/cast/reduce cleanups.
    irpass.cast_eliminate_,
    irpass.reshape_eliminate_,
    irpass.reduce_eliminate_,
    irpass.tile_eliminate_,
    irpass.transpose_eliminate_,
    irpass.minmaximum_grad_,

    // Arithmetic simplifications
    irpass.arithmetic_simplify_,
    irpass.addn_zero_filter_,
    irpass.adjust_all_reduce_mul_add_,
    irpass.accumulaten_eliminater_,

    // Safe inlining
    irpass.inline_,
    irpass.updatestate_useless_node_eliminater_,
    irpass.updatestate_pure_node_eliminater_,
    irpass.load_eliminater_,
    irpass.stopgrad_eliminater_,
    irpass.print_const_string_wrapper_,
  });
}
412 
// Apply the flash sequence-parallel transformation to the graph, at most once:
// a graph-level flag guards re-entry on subsequent optimizer iterations.
// The optimizer parameter is unused; it exists to match the pass signature.
bool FlashSPFrontPass(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer) {
  // Already transformed on a previous iteration — report "no change".
  if (func_graph->has_flag(parallel::FLASH_SP_RUN_ONCE_ONLY)) {
    return false;
  }
  auto result = parallel::SetFlashSP(func_graph);
  // Mark the graph so the transformation is never attempted again.
  func_graph->set_flag(parallel::FLASH_SP_RUN_ONCE_ONLY, true);
  return result;
}
421 
// Assemble the main "opt-A" pass group: simplification stages a_1..a_3
// interleaved with parallel/auto-parallel steps, grad preparation, meta-fg
// expansion, renormalization and CSE. Pass ORDER is load-bearing: GetA1A2()
// slices the first entries of this map by index, and AddParallelRenormalize()
// looks up the "meta_fg_expand" entry by name.
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig a_1 = GetOptPassA1(irpass);
  // NOTE(review): the trailing (false, true) flags configure the pass group's
  // run mode — confirm exact semantics against OptPassConfig's constructor.
  opt::OptPassConfig a_2 = opt::OptPassConfig(
    {
      irpass.switch_simplify_,
      irpass.specialize_transform_,
      irpass.merge_addn_,
      irpass.compare_switch_simplify_,
      irpass.addn_check_dump_,
      irpass.float_tuple_getitem_switch_,
      irpass.float_environ_get_switch_,
      irpass.inline_,
      irpass.updatestate_useless_node_eliminater_,
      irpass.arithmetic_simplify_,
      irpass.tuple_list_set_item_eliminator_,
      irpass.tuple_list_get_item_eliminator_,
      irpass.incorporate_call_,
      irpass.incorporate_call_switch_,
      irpass.environ_get_eliminate_,
      irpass.depend_value_elim_,
      irpass.all_reduce_const_elim_,
    },
    false, true);

  opt::OptPassConfig before_grad = opt::OptPassConfig({irpass.j_node_and_user_rematch_});

  opt::OptPassConfig a_after_grad = opt::OptPassConfig({irpass.inline_without_move_, irpass.stack_unstack_eliminate_});

  opt::OptPassConfig a_3 = opt::OptPassConfig(
    {
      irpass.same_eliminate_,
      irpass.check_bprop_eliminate_,
      irpass.switch_layer_defer_inline_,
      irpass.replace_applicator_,
      irpass.row_tensor_add_zeros_like_,
      irpass.mini_step_allgather_replace_,
      irpass.micro_step_allgather_replace_,
      irpass.split_environ_get_set_with_tuple_value_,
    },
    false, true);
  opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});
  opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
  opt::OptPassConfig after_resolve_pass = opt::OptPassConfig({irpass.replace_old_param_});
  // Disable after_resolve_pass if Pre-Lift is enabled.
  static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
  if (enable_pre_lift) {
    after_resolve_pass.set_disabled(true);
  }
  opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
  opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
  opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
  opt::OptPassConfig recompute_prepare = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
  opt::OptPassConfig get_grad = opt::OptPassConfig({irpass.get_grad_eliminate_});
  opt::OptPassConfig cell_reuse_handle_not_recompute_node_pass =
    opt::OptPassConfig({irpass.remove_not_recompute_node_}, false, true);

  opt::OptPassConfig c_1 = opt::OptPassConfig({
    irpass.switch_call_monad_eliminater_,
    irpass.partial_eliminate_,
  });
  // Disable c_1 if Pre-Lift is not enabled.
  if (!enable_pre_lift) {
    c_1.set_disabled(true);
  }
  // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
  OptPassGroupMap map_a({{"expand_dump_flag", opt::OptPassConfig(opt::irpass::ExpandDumpFlag())},
                         {"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
                         {"a_1", a_1},
                         {"recompute_prepare", recompute_prepare},
                         {"updatestate_depend_eliminate", updatestate_depend_eliminate},
                         {"updatestate_assign_eliminate", updatestate_assign_eliminate},
                         {"updatestate_loads_eliminate", updatestate_loads_eliminate},
                         {"c_1", c_1},
                         {"parameter_eliminate", opt::OptPassConfig(opt::irpass::ParameterEliminator())},
                         {"a_2", a_2},
                         {"accelerated_algorithm", accelerated_algorithm},
                         {"shard", opt::OptPassConfig(parallel::Shard)},
                         {"meta_shard_fg_expand", opt::OptPassConfig(opt::irpass::ExpandMetaShardFg())},
                         {"shard_inline", opt::OptPassConfig({irpass.inline_})},
                         {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
                         {"parallel", opt::OptPassConfig(parallel::StepParallel)},
                         {"flash_sp", opt::OptPassConfig(FlashSPFrontPass)},
                         {"merge_comm", opt::OptPassConfig(parallel::MergeComm)},
                         {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
                         {"matmul_add_comm_reduction", opt::OptPassConfig(parallel::MatmulAddCommReduction)},
                         {"virtual_shard_identity", opt::OptPassConfig({irpass.virtual_shard_identity_})},
                         {"virtual_dataset", virtual_dataset},
                         {"get_grad_eliminate_", get_grad},
                         {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
                         {"merge_forward", opt::OptPassConfig(ad::MergeForward)},
                         {"cell_reuse_recompute_pass", opt::OptPassConfig(opt::irpass::AddRecomputeNodes)},
                         {"cell_reuse_handle_not_recompute_node_pass", cell_reuse_handle_not_recompute_node_pass},
                         {"before_grad", before_grad},
                         {"meta_fg_expand", opt::OptPassConfig(opt::irpass::ExpandMetaFg())},
                         {"receive_attached", opt::OptPassConfig(parallel::IsolatedNodeAttach)},
                         {"after_resolve", after_resolve_pass},
                         {"a_after_grad", a_after_grad},
                         {"renormalize", opt::OptPassConfig::Renormalize()},
                         {"real_op_eliminate", opt::OptPassConfig({irpass.real_op_eliminate_})},
                         {"add_forward_monad_depend", opt::OptPassConfig(opt::irpass::AddForwardMonadDepend)},
                         {"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},
                         {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
                         {"cse", opt::OptPassConfig(opt::CSEPass(false))},
                         {"a_3", a_3}});
  // Inserts parallel_renormalize/cast_eliminate before "meta_fg_expand" when
  // running under (semi-)auto parallel mode.
  AddParallelRenormalize(&map_a);
  return map_a;
}
529 
// Return the leading slice of GetOptPassesA(): the first 10 entries, i.e.
// everything up to and including the "a_2" stage.
// NOTE: a1_a2_len indexes into map_a by position — keep it in sync with the
// entry order in GetOptPassesA() whenever that map changes.
OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
  auto opt_a = GetOptPassesA(irpass);
  constexpr auto a1_a2_len = 10;
  OptPassGroupMap a1_a2(opt_a.begin(), opt_a.begin() + a1_a2_len);
  return a1_a2;
}
536 
GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib & irpass)537 OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
538   opt::OptPassConfig c_1 = opt::OptPassConfig({
539     // Safe inlining,
540     irpass.inline_,
541     irpass.updatestate_useless_node_eliminater_,
542     irpass.updatestate_pure_node_eliminater_,
543     irpass.load_eliminater_,
544     irpass.switch_call_monad_eliminater_,
545     irpass.stopgrad_eliminater_,
546     irpass.partial_eliminate_,
547     irpass.slice_to_tuple_,
548   });
549   opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
550   opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
551   opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
552 
553   OptPassGroupMap map_a({{"c_1", c_1},
554                          {"parameter_eliminate", opt::OptPassConfig(opt::irpass::ParameterEliminator())},
555                          {"updatestate_depend_eliminate", updatestate_depend_eliminate},
556                          {"updatestate_assign_eliminate", updatestate_assign_eliminate},
557                          {"updatestate_loads_eliminate", updatestate_loads_eliminate},
558                          {"cse", opt::OptPassConfig(opt::CSEPass(false))},
559                          {"renormalize", opt::OptPassConfig::Renormalize()}});
560 
561   return map_a;
562 }
563 
GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib & irpass)564 OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
565   opt::OptPassConfig d_1 = opt::OptPassConfig({
566     irpass.call_graph_tuple_transform_,
567     irpass.list_to_tuple_eliminator_,
568     irpass.tuple_to_list_eliminator_,
569     irpass.tuple_list_get_item_eliminator_,
570     irpass.tuple_list_get_item_const_eliminator_,
571     irpass.tuple_list_set_item_eliminator_,
572     irpass.tuple_list_get_set_item_eliminator_,
573     irpass.tuple_list_get_item_depend_reorder_,
574     irpass.tuple_list_convert_item_index_to_positive_,
575   });
576 
577   OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
578 
579   return map_a;
580 }
581 
// Assemble the "opt-B" pass group: a broad cleanup stage (b_1) run after the
// opt-A pipeline, row-tensor elimination (b_2), updatestate cleanups,
// renormalization and CSE.
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
  // NOTE(review): the trailing (false, true) flags configure the pass group's
  // run mode — confirm exact semantics against OptPassConfig's constructor.
  opt::OptPassConfig b_1 = opt::OptPassConfig({irpass.zero_like_fill_zero_,
                                               irpass.list_to_tuple_eliminator_,
                                               irpass.tuple_to_list_eliminator_,
                                               irpass.tuple_list_get_item_eliminator_,
                                               irpass.tuple_list_get_item_const_eliminator_,
                                               irpass.tuple_list_set_item_eliminator_,
                                               irpass.tuple_list_get_set_item_eliminator_,
                                               irpass.tuple_list_get_item_depend_reorder_,
                                               irpass.tuple_list_convert_item_index_to_positive_,
                                               irpass.make_slice_get_slice_eliminator_,
                                               irpass.float_tuple_getitem_switch_,
                                               irpass.reset_defer_inline_,
                                               irpass.inline_,
                                               irpass.updatestate_useless_node_eliminater_,
                                               irpass.updatestate_pure_node_eliminater_,
                                               irpass.load_eliminater_,
                                               irpass.stopgrad_eliminater_,
                                               irpass.special_op_eliminate_,
                                               irpass.environ_get_eliminate_,
                                               irpass.environ_get_add_eliminate_,
                                               irpass.environ_get_set_eliminate_,
                                               irpass.environ_get_depend_swap_,
                                               irpass.environ_add_const_eliminate_,
                                               irpass.value_based_eliminate_,
                                               irpass.parallel_virtual_node_,
                                               irpass.const_output_eliminate_},
                                              false, true);
  opt::OptPassConfig b_2 = opt::OptPassConfig({
    irpass.row_tensor_eliminate_,
  });
  opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
  opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
  opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
  OptPassGroupMap map({
    {"b_1", b_1},
    {"b_2", b_2},
    {"updatestate_depend_eliminate", updatestate_depend_eliminate},
    {"updatestate_assign_eliminate", updatestate_assign_eliminate},
    {"updatestate_loads_eliminate", updatestate_loads_eliminate},
    {"renormalize", opt::OptPassConfig::Renormalize()},
    {"cse", opt::OptPassConfig(opt::CSEPass(false))},
  });
  return map;
}
627 
GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib & irpass)628 OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &irpass) {
629   opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
630     irpass.pynative_eliminate_,
631   });
632 
633   OptPassGroupMap map({
634     {"pynative_eliminate", pynative_eliminate},
635   });
636   return map;
637 }
638 
GetOptPassesC(const opt::irpass::OptimizeIRPassLib &)639 OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &) {
640   return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
641 }
642 
GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib & irpass)643 OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
644   auto opt_a = GetOptPassesA(irpass);
645   auto a3 = opt_a[opt_a.size() - 1];
646   OptPassGroupMap map({
647     {"renormalize", opt::OptPassConfig::Renormalize()},
648     {"cse", opt::OptPassConfig(opt::CSEPass(false))},
649     {a3},
650   });
651   return map;
652 }
653 
GetGradPartialTransformPhases()654 OptPassGroupMap GetGradPartialTransformPhases() {
655   opt::irpass::GradPartialPassLib irpass;
656   auto grad_partial_transform = opt::OptPassConfig({irpass.grad_partial_transform_});
657   opt::OptPassGroupMap grad_partial_transform_map({{"grad_partial_transform", grad_partial_transform}});
658   return grad_partial_transform_map;
659 }
660 
GetPreparePhases(const opt::irpass::OptimizeIRPassLib & irpass)661 OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
662   opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_});
663   OptPassGroupMap map({{"prepare_group", prepare_group}});
664   return map;
665 }
666 
GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &)667 OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
668   OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
669   return map;
670 }
671 
GetSymbolEngineOptPass(const opt::irpass::OptimizeIRPassLib & irpass)672 OptPassGroupMap GetSymbolEngineOptPass(const opt::irpass::OptimizeIRPassLib &irpass) {
673   if (common::GetEnv("MS_SYMBOL_ENGINE_OPTIMIZE") == "off") {
674     MS_LOG(INFO) << "SymbolEngineOptimizer is disabled.";
675     return OptPassGroupMap();
676   }
677   OptPassGroupMap map({{"build", opt::OptPassConfig(opt::irpass::SymbolEngineBuilder())},
678                        {"elim_shapecalc", opt::OptPassConfig({irpass.elim_shapecalc_of_broadcastargs_})},
679                        {"elim_not_effective", opt::OptPassConfig({irpass.elim_not_effective_node_})},
680                        {"opt_reshape", opt::OptPassConfig({irpass.opt_reshape_})},
681                        {"fold_const_symbol", opt::OptPassConfig({irpass.fold_const_symbol_})},
682                        {"shape_op_cse", opt::OptPassConfig(opt::irpass::ShapeOpCse())},
683                        {"renormalize", opt::OptPassConfig::Renormalize()}});
684   return map;
685 }
686 
// Cache of named optimizers, built lazily by InitOpt() and released by ReclaimOptimizer().
static mindspore::HashMap<std::string, std::shared_ptr<Optimizer>> g_pass_opts = {};
688 
InitOpt(const ResourcePtr & resource)689 void InitOpt(const ResourcePtr &resource) {
690   if (g_pass_opts.size() == 0) {
691     opt::irpass::OptimizeIRPassLib irpass;
692     g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", resource, GetA1A2(irpass));
693     g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", resource, GetOptPassesA(irpass));
694     g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", resource, GetOptPassesB(irpass), false, true);
695     g_pass_opts["opt_after_cconv"] =
696       Optimizer::MakeOptimizer("opt_after_cconv", resource, GetOptPassesAfterCconv(irpass), false, true);
697     g_pass_opts["opt_trans_graph"] =
698       Optimizer::MakeOptimizer("opt_trans_graph", resource, GetOptPassesTransformGraph(irpass), true, true);
699     g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", resource, GetOptPassesC(irpass));
700     g_pass_opts["opt_grad_epilogue"] =
701       Optimizer::MakeOptimizer("opt_grad_epilogue", resource, GetOptPynativeGradEpiloguePhases(irpass), true, false);
702     g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", resource, GetPreparePhases(irpass));
703     g_pass_opts["opt_after_recompute"] =
704       Optimizer::MakeOptimizer("opt_after_recompute", resource, GetAfterRecomputePass(irpass));
705     g_pass_opts["symbol_engine_opt"] =
706       Optimizer::MakeOptimizer("symbol_engine_opt", resource, GetSymbolEngineOptPass(irpass), true, true);
707   }
708 }
709 }  // namespace
710 
ReclaimOptimizer()711 void ReclaimOptimizer() {
712   for (auto &opt : g_pass_opts) {
713     opt.second = nullptr;
714   }
715   g_pass_opts.clear();
716 }
717 
OptPassGroup(const ResourcePtr & resource,const std::string & name)718 bool OptPassGroup(const ResourcePtr &resource, const std::string &name) {
719   MS_EXCEPTION_IF_NULL(resource);
720   if (resource->func_graph() == nullptr) {
721     MS_LOG(ERROR) << "Opt passes int64_t error";
722     return false;
723   }
724 
725   FuncGraphPtr func_graph = resource->func_graph();
726   MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
727                 << func_graph->get_return()->DebugString(true);
728   InitOpt(resource);
729   if (g_pass_opts.find(name) != g_pass_opts.end()) {
730     resource->set_func_graph(g_pass_opts[name]->step(func_graph));
731   }
732   // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to
733   // resource->args_abs_ yet. So if any later pass or action want to use that variable, it should be set here.
734   return true;
735 }
736 
OptPassA1A2(const ResourcePtr & resource)737 bool OptPassA1A2(const ResourcePtr &resource) { return OptPassGroup(resource, "a1a2"); }
OptPassAGroup(const ResourcePtr & resource)738 bool OptPassAGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_a"); }
OptPassBGroup(const ResourcePtr & resource)739 bool OptPassBGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_b"); }
OptPassAfterCconvGroup(const ResourcePtr & resource)740 bool OptPassAfterCconvGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_after_cconv"); }
OptPassTransformGraphGroup(const ResourcePtr & resource)741 bool OptPassTransformGraphGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_trans_graph"); }
ControlGroup(const ResourcePtr & resource)742 bool ControlGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_control"); }
PrepareGroup(const ResourcePtr & resource)743 bool PrepareGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_prepare"); }
OptAfterRecomputeGroup(const ResourcePtr & resource)744 bool OptAfterRecomputeGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_after_recompute"); }
745 
OptPassRNGroup(const ResourcePtr & resource)746 bool OptPassRNGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "renormal"); }
SymEngOptGroup(const ResourcePtr & resource)747 bool SymEngOptGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "symbol_engine_opt"); }
748 
OptPassGradEpilogueGroup(const ResourcePtr & resource)749 bool OptPassGradEpilogueGroup(const ResourcePtr &resource) { return OptPassGroup(resource, "opt_grad_epilogue"); }
750 
AddRecomputationPass(const ResourcePtr & resource)751 bool AddRecomputationPass(const ResourcePtr &resource) {
752   auto context = MsContext::GetInstance();
753   MS_EXCEPTION_IF_NULL(context);
754   if (context->CellReuseLevel() != CellReuseLevel::kNoCellReuse) {
755     return true;
756   }
757   MS_EXCEPTION_IF_NULL(resource);
758   opt::InsertRecomputedNodes(resource->func_graph());
759   return true;
760 }
761 
SliceRecomputeActivationPass(const ResourcePtr & resource)762 bool SliceRecomputeActivationPass(const ResourcePtr &resource) {
763   MS_EXCEPTION_IF_NULL(resource);
764   opt::SliceRecomputedActivationNodes(resource->func_graph());
765   return true;
766 }
767 
GroupedPairwiseExchangeAllToAllPass(const ResourcePtr & resource)768 bool GroupedPairwiseExchangeAllToAllPass(const ResourcePtr &resource) {
769   MS_EXCEPTION_IF_NULL(resource);
770   opt::SetGroupedPairwiseExchangeAllToAll(resource);
771   return true;
772 }
773 
SliceReuseRecomputedActivationPass(const ResourcePtr & resource)774 bool SliceReuseRecomputedActivationPass(const ResourcePtr &resource) {
775   MS_EXCEPTION_IF_NULL(resource);
776   parallel::SliceReuseRecomputedActivationNodes(resource->func_graph());
777   return true;
778 }
779 
LabelMicroInterleavedIndexPass(const ResourcePtr & resource)780 bool LabelMicroInterleavedIndexPass(const ResourcePtr &resource) {
781   MS_EXCEPTION_IF_NULL(resource);
782   parallel::LabelMicroInterleavedIndex(resource->func_graph());
783   return true;
784 }
785 
OverlapRecomputeAllGatherAndFlashAttentionGradPass(const ResourcePtr & resource)786 bool OverlapRecomputeAllGatherAndFlashAttentionGradPass(const ResourcePtr &resource) {
787   MS_EXCEPTION_IF_NULL(resource);
788   parallel::OverlapRecomputeAllGatherAndFlashAttentionGrad(resource->func_graph());
789   return true;
790 }
791 
OptimizeParallelAllGatherCommPass(const ResourcePtr & resource)792 bool OptimizeParallelAllGatherCommPass(const ResourcePtr &resource) {
793   MS_EXCEPTION_IF_NULL(resource);
794   parallel::OptimizeParallelAllGatherComm(resource->func_graph());
795   return true;
796 }
797 
LabelFineGrainedInterleavedIndexPass(const ResourcePtr & resource)798 bool LabelFineGrainedInterleavedIndexPass(const ResourcePtr &resource) {
799   MS_EXCEPTION_IF_NULL(resource);
800   parallel::LabelFineGrainedInterleavedIndex(resource->func_graph());
801   return true;
802 }
803 
AssignAddOpt(const ResourcePtr & resource)804 bool AssignAddOpt(const ResourcePtr &resource) {
805   MS_EXCEPTION_IF_NULL(resource);
806   FuncGraphPtr func_graph = resource->func_graph();
807   MS_EXCEPTION_IF_NULL(func_graph);
808   parallel::AssignAddOpt(func_graph);
809   auto ms_context = MsContext::GetInstance();
810   auto enable_concat_eliminate = ms_context->get_param<bool>(MS_CTX_ENABLE_CONCAT_ELIMINATE_OPT);
811   if (!enable_concat_eliminate) {
812     return true;
813   }
814   OptPassGroupMap map({{"renormalize", opt::OptPassConfig({opt::OptPassConfig::Renormalize()})}});
815   auto renormalize = opt::Optimizer::MakeOptimizer("renormalize", resource, map);
816   (void)renormalize->step(func_graph, false);
817   return true;
818 }
819 
PartialUnusedArgsEliminatePass(const ResourcePtr & resource)820 bool PartialUnusedArgsEliminatePass(const ResourcePtr &resource) {
821   MS_EXCEPTION_IF_NULL(resource);
822   FuncGraphPtr func_graph = resource->func_graph();
823   auto opt = opt::irpass::PartialUnusedArgsEliminate();
824   auto changed = opt(func_graph);
825   if (changed) {
826     OptPassGroupMap map({{"renormalize", opt::OptPassConfig({opt::OptPassConfig::Renormalize()})}});
827     auto renormalize = opt::Optimizer::MakeOptimizer("renormalize", resource, map);
828     (void)renormalize->step(func_graph, false);
829   }
830   return true;
831 }
832 
MergeCastOpt(const ResourcePtr & resource)833 bool MergeCastOpt(const ResourcePtr &resource) {
834   MS_EXCEPTION_IF_NULL(resource);
835   parallel::MergeCastOpt(resource->func_graph());
836   return true;
837 }
838 
ForceFp32Comm(const ResourcePtr & resource)839 bool ForceFp32Comm(const ResourcePtr &resource) {
840   MS_EXCEPTION_IF_NULL(resource);
841   parallel::Float32Redistribution(resource->func_graph());
842   return true;
843 }
844 
SwapDpAllReduceReduceScatterPass(const ResourcePtr & resource)845 bool SwapDpAllReduceReduceScatterPass(const ResourcePtr &resource) {
846   MS_EXCEPTION_IF_NULL(resource);
847   parallel::SwapDpAllreduceReduceScatter(resource->func_graph());
848   return true;
849 }
850 
RemoveCastBeforeAssignAdd(const ResourcePtr & resource)851 bool RemoveCastBeforeAssignAdd(const ResourcePtr &resource) {
852   MS_EXCEPTION_IF_NULL(resource);
853   parallel::RemoveCastBeforeAssignAdd(resource->func_graph());
854   return true;
855 }
856 
BiasAddCommSwap(const ResourcePtr & resource)857 bool BiasAddCommSwap(const ResourcePtr &resource) {
858   MS_EXCEPTION_IF_NULL(resource);
859   parallel::BiasAddCommSwap(resource->func_graph());
860   return true;
861 }
862 
ReorderSendRecvBetweenFpBpPass(const ResourcePtr & resource)863 bool ReorderSendRecvBetweenFpBpPass(const ResourcePtr &resource) {
864   MS_EXCEPTION_IF_NULL(resource);
865   parallel::ReorderSendRecvBetweenFpBp(resource->func_graph());
866   return true;
867 }
868 
CompCommSchedulingPass(const ResourcePtr & resource)869 bool CompCommSchedulingPass(const ResourcePtr &resource) {
870   MS_EXCEPTION_IF_NULL(resource);
871   opt::CompCommScheduling(resource->func_graph());
872   return true;
873 }
874 
MicroInterLeavedOrderControlPass(const ResourcePtr & resource)875 bool MicroInterLeavedOrderControlPass(const ResourcePtr &resource) {
876   MS_EXCEPTION_IF_NULL(resource);
877   parallel::MicroInterleavedOrderControl(resource->func_graph());
878   return true;
879 }
880 
OverlapGradCommPass(const ResourcePtr & resource)881 bool OverlapGradCommPass(const ResourcePtr &resource) {
882   MS_EXCEPTION_IF_NULL(resource);
883   parallel::OverlapGradComm(resource->func_graph());
884   return true;
885 }
886 
FullMicroInterLeavedOrderControlPass(const ResourcePtr & resource)887 bool FullMicroInterLeavedOrderControlPass(const ResourcePtr &resource) {
888   MS_EXCEPTION_IF_NULL(resource);
889   parallel::FullMicroInterleavedOrderControl(resource->func_graph());
890   return true;
891 }
892 
SplitMatmulCommElementwiseOpFpPass(const ResourcePtr & resource)893 bool SplitMatmulCommElementwiseOpFpPass(const ResourcePtr &resource) {
894   MS_EXCEPTION_IF_NULL(resource);
895   parallel::SplitMatmulCommElementwiseFp(resource->func_graph());
896   return true;
897 }
898 
SplitLayerNormCommFpPass(const ResourcePtr & resource)899 bool SplitLayerNormCommFpPass(const ResourcePtr &resource) {
900   MS_EXCEPTION_IF_NULL(resource);
901   parallel::SplitLayerNormCommFp(resource->func_graph());
902   return true;
903 }
904 
CommOpAddAttrs(const ResourcePtr & resource)905 bool CommOpAddAttrs(const ResourcePtr &resource) {
906   MS_EXCEPTION_IF_NULL(resource);
907   opt::CommOpAttrs(resource->func_graph());
908   return true;
909 }
910 
ProcessSendRecvForGE(const ResourcePtr & resource)911 bool ProcessSendRecvForGE(const ResourcePtr &resource) {
912   MS_EXCEPTION_IF_NULL(resource);
913   opt::ProcessSendRecvForGE(resource->func_graph());
914   return true;
915 }
916 
AddCommOpReusePass(const ResourcePtr & resource)917 bool AddCommOpReusePass(const ResourcePtr &resource) {
918   MS_EXCEPTION_IF_NULL(resource);
919   opt::AddCommOpReuseTag(resource->func_graph());
920   return true;
921 }
922 
OverlapOptShardInPipelinePass(const ResourcePtr & resource)923 bool OverlapOptShardInPipelinePass(const ResourcePtr &resource) {
924   MS_EXCEPTION_IF_NULL(resource);
925   parallel::OverlapOptShardInPipeline(resource->func_graph());
926   return true;
927 }
928 
BeginEndOverlapInlinePass(const ResourcePtr & resource)929 bool BeginEndOverlapInlinePass(const ResourcePtr &resource) {
930   auto ms_context = MsContext::GetInstance();
931   auto is_enable = ms_context->get_param<bool>(MS_CTX_ENABLE_BEGIN_END_INLINE_OPT);
932   if (!is_enable) {
933     return true;
934   }
935   MS_EXCEPTION_IF_NULL(resource);
936   FuncGraphPtr func_graph = resource->func_graph();
937   MS_EXCEPTION_IF_NULL(func_graph);
938   parallel::BeginEndOverlapInlineOpt(resource->func_graph());
939   opt::irpass::OptimizeIRPassLib irpass;
940   opt::OptPassConfig get_item_eliminator_pass = opt::OptPassConfig({irpass.tuple_list_get_item_eliminator_});
941   OptPassGroupMap map({{"get_item_eliminator", get_item_eliminator_pass}});
942   auto get_item_eliminator = opt::Optimizer::MakeOptimizer("get_item_eliminator", resource, map);
943   (void)get_item_eliminator->step(func_graph, false);
944   return true;
945 }
946 
OverlapGradMatmulAndGradAllreduce(const ResourcePtr & resource)947 bool OverlapGradMatmulAndGradAllreduce(const ResourcePtr &resource) {
948   MS_EXCEPTION_IF_NULL(resource);
949   parallel::OverlapGradMatmulAndGradAllreduce(resource->func_graph());
950   return true;
951 }
952 
OverlapOptShardGradInPipelinePass(const ResourcePtr & resource)953 bool OverlapOptShardGradInPipelinePass(const ResourcePtr &resource) {
954   MS_EXCEPTION_IF_NULL(resource);
955   parallel::OverlapOptShardGradInPipeline(resource->func_graph());
956   return true;
957 }
958 
HandleGroupInfoPass(const ResourcePtr & resource)959 bool HandleGroupInfoPass(const ResourcePtr &resource) {
960   MS_EXCEPTION_IF_NULL(resource);
961   parallel::HandleGroupInfo();
962   return true;
963 }
964 
OverlapRecomputeAndGradModelParallel(const ResourcePtr & resource)965 bool OverlapRecomputeAndGradModelParallel(const ResourcePtr &resource) {
966   MS_EXCEPTION_IF_NULL(resource);
967   parallel::OverlapRecomputeAndGradModelParallel(resource->func_graph());
968   return true;
969 }
970 
// Adds cache-embedding nodes to the graph and, when the cache flag is set,
// renormalizes to refresh abstracts. Skipped entirely in Parameter Server mode.
bool AddCacheEmbeddingPass(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
#if defined(__linux__) && defined(WITH_BACKEND)
  // PS mode handles embedding caching elsewhere; nothing to do here.
  if (ps::PSContext::instance()->is_ps_mode()) {
    return true;
  }
#endif
  FuncGraphPtr func_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);

  parallel::AddCacheEmbedding(func_graph);
  if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) {
    // Collect the current parameter abstracts as renormalization inputs.
    auto params = func_graph->parameters();
    AbstractBasePtrList args_abs_list;
    (void)std::for_each(params.begin(), params.end(),
                        [&args_abs_list](const AnfNodePtr &node) { args_abs_list.push_back(node->abstract()); });
    // NOTE(review): the renormalized graph is assigned to the local variable only
    // and is never stored back via resource->set_func_graph() — confirm intended.
    func_graph = pipeline::Renormalize(resource, func_graph, args_abs_list);
  }
  return true;
}
991 
RemoveValueNodeDuplicationsPass(const ResourcePtr & resource)992 bool RemoveValueNodeDuplicationsPass(const ResourcePtr &resource) {
993   MS_EXCEPTION_IF_NULL(resource);
994   if (resource->func_graph() == nullptr) {
995     MS_LOG(INTERNAL_EXCEPTION) << "Remove value node duplications error.";
996   }
997   auto manager = resource->manager();
998   HashCache hash_cache;
999   HashValue hashes;
1000   // Remove duplicated value nodes across all graphs in manager
1001   const auto &node_user_map = manager->node_users();
1002   for (auto &fg : manager->func_graphs()) {
1003     auto value_nodes = fg->value_nodes();
1004     for (const auto &value_pair : value_nodes) {
1005       auto &users = node_user_map.at(value_pair.first);
1006       auto prim = GetValueNode<PrimitivePtr>(value_pair.first);
1007       if (IsPrimitiveEquals(prim, prim::kPrimUpdateState)) {
1008         continue;
1009       }
1010       // For data parallel with some parameters redundant, the allreduce will share the same value node
1011       // which will raise an error when do allreduce fusion, so the solution is to make the allreduce's value node
1012       // not be removed, if we found the fusion tag.
1013       if (users.size() == 1) {
1014         auto cnode = users.front().first->cast<CNodePtr>();
1015         if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce) && cnode->size() > 1 && cnode->input(1)->isa<ValueNode>()) {
1016           auto allreduce_prim = GetCNodePrimitive(users.front().first);
1017           auto attrs = allreduce_prim->attrs();
1018           auto fusion_id = attrs.find(mindspore::parallel::FUSION);
1019           if (fusion_id != attrs.end() && GetValue<int64_t>(fusion_id->second) > 0) {
1020             continue;
1021           }
1022         }
1023       }
1024       TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
1025     }
1026   }
1027   return true;
1028 }
1029 
CconvPass(const ResourcePtr & resource)1030 bool CconvPass(const ResourcePtr &resource) {
1031   MS_EXCEPTION_IF_NULL(resource);
1032   MS_EXCEPTION_IF_NULL(resource->func_graph());
1033   FuncGraphPtr func_graph = resource->func_graph();
1034   FuncGraphPtr new_fg = LiftingClone(func_graph);
1035   resource->set_func_graph(new_fg);
1036   return true;
1037 }
1038 
PipelineSplitPass(const ResourcePtr & resource)1039 bool PipelineSplitPass(const ResourcePtr &resource) { return PipelineSplit(resource); }
1040 
ParallelVirtualDatasetPass(const ResourcePtr & resource)1041 bool ParallelVirtualDatasetPass(const ResourcePtr &resource) { return ParallelVirtualDataset(resource); }
1042 
// Reorders the graph according to the configured pipeline scheduler.
// The scheduler only runs for (semi-)auto parallel with pipeline interleave
// enabled and more than one stage; Send/Recv processing for GE runs regardless.
bool PipelineParallelScheduler(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto root = resource->func_graph();
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  auto parallel_mode = parallel_context->parallel_mode();
  if (parallel_mode != parallel::kSemiAutoParallel && parallel_mode != parallel::kAutoParallel) {
    MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
    return true;
  }
  auto is_pp_interleave = parallel_context->pipeline_interleave();
  auto stage_num = parallel_context->pipeline_stage_split_num();
  if (is_pp_interleave && stage_num > 1) {
    auto manager = resource->manager();
    auto stage = parallel::InferStage();
    auto pp_scheduler = parallel_context->pipeline_scheduler();
    // Choose the scheduler implementation configured in the parallel context;
    // any other value is a hard configuration error.
    std::shared_ptr<parallel::PipelineScheduler> scheduler = nullptr;
    if (pp_scheduler == parallel::kPipeline1F1B) {
      scheduler = std::make_shared<parallel::InterleavedScheduler>(manager, root, stage, stage_num);
    } else if (pp_scheduler == parallel::kPipelineGpipe) {
      scheduler = std::make_shared<parallel::GpipeInterleavedScheduler>(manager, root, stage, stage_num);
    } else {
      MS_LOG(EXCEPTION) << "Unsupported pipeline parallel scheduler: " << pp_scheduler;
    }
    // Locate stage border nodes first, then reorder execution accordingly.
    scheduler->GetBorderNode();
    scheduler->Reorder();
  }
  opt::ProcessSendRecvForGE(root);
  return true;
}
1073 
AutoParallelPass(const ResourcePtr & resource)1074 bool AutoParallelPass(const ResourcePtr &resource) {
1075   auto func_graph = resource->func_graph();
1076   auto opt = opt::Optimizer::MakeEmptyOptimizer(resource);
1077   return parallel::StepAutoParallel(func_graph, opt);
1078 }
1079 
AutoParallelSymbolPassWithReNormalize(const ResourcePtr & resource)1080 bool AutoParallelSymbolPassWithReNormalize(const ResourcePtr &resource) {
1081   // 1, auto parallel; 2, dynamic shape
1082   auto func_graph = resource->func_graph();
1083   if (!parallel::IsParallelDynamicShape(func_graph)) {
1084     return true;
1085   }
1086   MS_LOG(INFO) << "symbol pass for parallel begin";
1087   // must be bind with renormalize
1088   OptPassGroupMap opt_map({{"renormalize", opt::OptPassConfig::Renormalize()},
1089                            {"build", opt::OptPassConfig(opt::irpass::SymbolEngineBuilder())}});
1090   auto opt = opt::Optimizer::MakeOptimizer("parallel-infer-symbol", resource, opt_map, true);
1091   (void)opt->step(func_graph, false);
1092   MS_LOG(INFO) << "symbol pass for parallel end";
1093   return true;
1094 }
1095 
ValidatePass(const ResourcePtr & resource)1096 bool ValidatePass(const ResourcePtr &resource) {
1097   MS_EXCEPTION_IF_NULL(resource);
1098   MS_EXCEPTION_IF_NULL(resource->func_graph());
1099   FuncGraphPtr func_graph = resource->func_graph();
1100   Validate(func_graph);
1101   return true;
1102 }
1103 
GradPartialTransformPass(const ResourcePtr & resource)1104 bool GradPartialTransformPass(const ResourcePtr &resource) {
1105   MS_EXCEPTION_IF_NULL(resource);
1106   FuncGraphPtr func_graph = resource->func_graph();
1107   MS_EXCEPTION_IF_NULL(func_graph);
1108   auto grad_partial_transform_map = GetGradPartialTransformPhases();
1109   auto grad_partial_transform =
1110     opt::Optimizer::MakeOptimizer("grad_partial_transform", resource, grad_partial_transform_map);
1111   (void)grad_partial_transform->step(func_graph, false);
1112   return true;
1113 }
1114 
PynativeOptPass(const ResourcePtr & resource)1115 bool PynativeOptPass(const ResourcePtr &resource) {
1116   MS_EXCEPTION_IF_NULL(resource);
1117   FuncGraphPtr func_graph = resource->func_graph();
1118   MS_EXCEPTION_IF_NULL(func_graph);
1119   opt::irpass::OptimizeIRPassLib irpass;
1120   auto pynative_opt = GetOptPassesPynativeElim(irpass);
1121   auto pynative_opt_opt = opt::Optimizer::MakeOptimizer("pynative_opt", resource, pynative_opt);
1122   (void)pynative_opt_opt->step(func_graph, false);
1123   return true;
1124 }
1125 
EliminateSpecialOpOptPass(const ResourcePtr & resource)1126 bool EliminateSpecialOpOptPass(const ResourcePtr &resource) {
1127   MS_EXCEPTION_IF_NULL(resource);
1128   auto func_graph = resource->func_graph();
1129   MS_EXCEPTION_IF_NULL(func_graph);
1130   opt::irpass::OptimizeIRPassLib irpass;
1131   opt::OptPassConfig ad_related_special_op_eliminate = opt::OptPassConfig({
1132     irpass.ad_related_special_op_eliminate_,
1133   });
1134   opt::OptPassConfig mutable_op_eliminate = opt::OptPassConfig({
1135     irpass.mutable_op_eliminate_,
1136   });
1137   opt::OptPassConfig convert_tensor_op_eliminate = opt::OptPassConfig({
1138     irpass.convert_tensor_all_eliminate_,
1139   });
1140   OptPassGroupMap map({
1141     {"ad_related_special_op_eliminate", ad_related_special_op_eliminate},
1142     {"mutable_op_eliminate", mutable_op_eliminate},
1143     {"convert_tensor_op_eliminate", convert_tensor_op_eliminate},
1144   });
1145   auto special_op_eliminate_opt = opt::Optimizer::MakeOptimizer("special_op_eliminate", resource, map);
1146   (void)special_op_eliminate_opt->step(func_graph, false);
1147   return true;
1148 }
1149 
// Eliminates redundant auto-monad (UpdateState/Load) nodes on a bare func
// graph. Builds a throwaway Resource around the graph's existing manager so
// the optimizer machinery can run outside the normal pipeline.
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(func_graph->manager());
  auto resource = std::make_shared<pipeline::Resource>();
  resource->set_func_graph(func_graph);
  resource->set_manager(func_graph->manager());

  // opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
  opt::SubstitutionPtr updatestate_useless_node_eliminater =
    opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateUselessNodeEliminater>(),
                          "updatestate_useless_node_eliminater", prim::kPrimUpdateState);
  opt::SubstitutionPtr updatestate_pure_node_eliminater =
    opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestatePureNodeEliminater>(),
                          "updatestate_pure_node_eliminater", prim::kPrimUpdateState);

  // Both substitutions target UpdateState nodes and run as a single group.
  opt::OptPassConfig updatestate_eliminater = opt::OptPassConfig({
    updatestate_useless_node_eliminater,
    updatestate_pure_node_eliminater,
  });
  opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
  opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
  opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
  opt::OptPassGroupMap elim_map({
    {"updatestate_eliminater", updatestate_eliminater},
    {"updatestate_depend_eliminate", updatestate_depend_eliminate},
    {"updatestate_assign_eliminate", updatestate_assign_eliminate},
    {"updatestate_loads_eliminate", updatestate_loads_eliminate},
    {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
  });

  auto auto_monad_elim_opt = opt::Optimizer::MakeOptimizer("auto_monad_elim", resource, elim_map);
  (void)auto_monad_elim_opt->step(func_graph, false);
  return true;
}
1184 
EnvironConversionPass(const ResourcePtr & resource)1185 bool EnvironConversionPass(const ResourcePtr &resource) {
1186   MS_EXCEPTION_IF_NULL(resource);
1187   (void)opt::EnvironConversion(resource);
1188   return true;
1189 }
1190 
// Build service-side graph for embedding distributed cache based on Parameter Server.
// Only active on server nodes of an initialized cluster with cache enabled;
// a no-op (returning true) everywhere else and on non-CPU/Windows/macOS builds.
bool AddEmbeddingCachePass(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
#if ((defined ENABLE_CPU) && (!defined _WIN32) && !defined(__APPLE__))
  if (!ps::PSContext::instance()->cache_enable() || !distributed::cluster::ClusterContext::instance()->initialized() ||
      !ps::PSContext::instance()->is_server()) {
    return true;
  }

  FuncGraphPtr func_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto node = distributed::cluster::ClusterContext::instance()->node();
  MS_EXCEPTION_IF_NULL(node);

  // 1. Build service-size graph.
  auto node_role = distributed::cluster::ClusterContext::instance()->node_role();
  uint32_t worker_num = ps::PSContext::instance()->worker_num();
  std::shared_ptr<parallel::PsEmbeddingCacheInserter> embedding_cache_inserter =
    std::make_shared<parallel::PsEmbeddingCacheInserter>(func_graph, static_cast<int64_t>(node->rank_id()), node_role,
                                                         worker_num);
  if (!embedding_cache_inserter->Run()) {
    MS_LOG(ERROR) << "Insert ps embedding cache failed.";
    return false;
  }

  // 2. Renomalize: Infer shape and Set abstract for all nodes in graph.
  // Unlike AddCacheEmbeddingPass, the renormalized graph IS written back here.
  abstract::AbstractBasePtrList args_abs;
  auto parameters = func_graph->parameters();
  (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
                       [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
  FuncGraphPtr new_fg = Renormalize(resource, func_graph, args_abs);
  resource->set_func_graph(new_fg);
  resource->set_args_abs(args_abs);
#endif

  return true;
}
1228 
// Ordered pass pipeline for graph-mode (VM) compilation; each entry is a
// (phase name, pass function) pair executed in sequence. Order is significant.
std::vector<PassItem> kVmPasses = {
  {"py_interpret_to_execute", PyInterpretToExecutePass},
  {"rewriter_before_opt_a", RewriterBeforeOptAPass},
  {"opt_a", OptPassAGroup},
  {"py_interpret_to_execute_after_opt_a", PyInterpretToExecutePass},
  {"slice_cell_reuse_recomputed_activation", SliceReuseRecomputedActivationPass},
  {"rewriter_after_opt_a", RewriterAfterOptAPass},
  {"convert_after_rewriter", ConvertAfterRewriterPass},
  {"order_py_execute_after_rewriter", OrderPyExecuteAfterRewriterPass},
  {"opt_b", OptPassBGroup},
  {"optimize_parallel_all_gather_comm", OptimizeParallelAllGatherCommPass},
  {"cconv", CconvPass},
  {"opt_after_cconv", OptPassAfterCconvGroup},
  {"remove_dup_value", RemoveValueNodeDuplicationsPass},
  {"tuple_transform", OptPassTransformGraphGroup},
  {"partial_unused_args_eliminate", PartialUnusedArgsEliminatePass},
  {"add_cache_embedding", AddCacheEmbeddingPass},
  {"add_recomputation", AddRecomputationPass},
  {"cse_after_recomputation", OptAfterRecomputeGroup},
  {"environ_conv", EnvironConversionPass},
  {"swap_dp_allreduce_reducescatter", SwapDpAllReduceReduceScatterPass},
  {"bias_add_comm_swap", BiasAddCommSwap},
  {"label_micro_interleaved_index", LabelMicroInterleavedIndexPass},
  {"label_fine_grained_interleaved_index", LabelFineGrainedInterleavedIndexPass},
  {"merge_cast_opt", MergeCastOpt},
  {"slice_recompute_activation", SliceRecomputeActivationPass},
  {"micro_interleaved_order_control", MicroInterLeavedOrderControlPass},
  {"assign_add_opt", AssignAddOpt},
  {"ForceFp32Comm", ForceFp32Comm},
  {"remove_cast_before_assign_add", RemoveCastBeforeAssignAdd},
  {"full_micro_interleaved_order_control", FullMicroInterLeavedOrderControlPass},
  {"comp_comm_scheduling", CompCommSchedulingPass},
  {"reorder_send_recv_between_fp_bp", ReorderSendRecvBetweenFpBpPass},
  {"comm_op_add_attrs", CommOpAddAttrs},
  {"add_comm_op_reuse_tag", AddCommOpReusePass},
  {"overlap_opt_shard_in_pipeline", OverlapOptShardInPipelinePass},
  {"overlap_opt_shard_grad_in_pipeline", OverlapOptShardGradInPipelinePass},
  {"grouped_pairwise_exchange_alltoall", GroupedPairwiseExchangeAllToAllPass},
  {"overlap_recompute_and_grad_model_parallel", OverlapRecomputeAndGradModelParallel},
  {"overlap_grad_matmul_and_grad_allreduce", OverlapGradMatmulAndGradAllreduce},
  {"overlap_recompute_allgather_and_fa_grad", OverlapRecomputeAllGatherAndFlashAttentionGradPass},
  {"begin_end_overlap_inline", BeginEndOverlapInlinePass},
  {"overlap_grad_comm", OverlapGradCommPass},
  {"split_matmul_comm_elemetwise", SplitMatmulCommElementwiseOpFpPass},
  {"split_layernorm_comm", SplitLayerNormCommFpPass},
  // The pass cache hccl group, so the hccl group should be created before the pass
  {"handle_group_info", HandleGroupInfoPass},
  {"symbol_engine_optimizer", SymEngOptGroup}};
1277 
1278 std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
1279                                          {"opt_b", OptPassBGroup},
1280                                          {"cconv", CconvPass},
1281                                          {"transform_top", TransformTopGraphPass},
1282                                          {"transform_graph", OptPassTransformGraphGroup}};
1283 
1284 std::vector<PassItem> kInlinePasses = {{"rewriter_before_opt_a", RewriterBeforeOptAPass}, {"a1a2", OptPassA1A2}};
1285 }  // namespace pipeline
1286 }  // namespace mindspore
1287