• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/pass.h"
18 
19 #include <memory>
20 #include <vector>
21 #include <string>
22 #include <unordered_map>
23 #include <algorithm>
24 
25 #include "ir/func_graph_cloner.h"
26 #include "pipeline/jit/parse/parse_base.h"
27 #include "pipeline/jit/resource.h"
28 #include "pipeline/jit/validator.h"
29 #include "pipeline/jit/remove_value_node_dup.h"
30 #include "frontend/optimizer/opt.h"
31 #include "frontend/optimizer/optimizer.h"
32 #include "frontend/optimizer/cse_pass.h"
33 #include "frontend/optimizer/clean.h"
34 #include "frontend/optimizer/irpass.h"
35 #include "frontend/optimizer/graph_transform.h"
36 #include "frontend/optimizer/auto_monad_eliminate.h"
37 #include "frontend/parallel/context.h"
38 #include "frontend/parallel/step_parallel.h"
39 #include "frontend/parallel/step_auto_parallel.h"
40 #include "frontend/parallel/cache_embedding/cache_embedding.h"
41 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
42 #include "frontend/optimizer/recompute.h"
43 #include "utils/log_adapter.h"
44 #include "pipeline/jit/pipeline_split.h"
45 #include "pipeline/pynative/pynative_execute.h"
46 #include "pipeline/jit/static_analysis/auto_monad.h"
47 #include "frontend/optimizer/irpass/branch_culling.h"
48 #include "frontend/optimizer/irpass/gradient_eliminate.h"
49 #include "frontend/optimizer/irpass/parameter_eliminate.h"
50 #include "frontend/optimizer/irpass/updatestate_eliminate.h"
51 #if ((defined ENABLE_CPU) && (!defined _WIN32))
52 #include "ps/util.h"
53 #include "ps/ps_context.h"
54 #endif
55 
56 namespace mindspore {
57 namespace pipeline {
58 using OptPassGroupMap = opt::OptPassGroupMap;
59 using Optimizer = opt::Optimizer;
60 using CompileGraphs = compile::CompileGraphs;
61 using abstract::AnalysisResult;
62 using mindspore::abstract::AnalysisContextPtr;
63 using mindspore::validator::Validate;
64 namespace {
DoRenormalize(const bool & changed,const FuncGraphPtr & func_graph,const ResourcePtr & res)65 void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const ResourcePtr &res) {
66   MS_EXCEPTION_IF_NULL(func_graph);
67   MS_EXCEPTION_IF_NULL(res);
68   abstract::AbstractBasePtrList args_spec;
69   auto parameters = func_graph->parameters();
70   (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
71                        [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
72   if (changed) {
73     FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
74     res->set_func_graph(new_fg);
75   }
76   res->set_args_spec(args_spec);
77 }
78 }  // namespace
79 
SimplifyDataStructuresPass(const ResourcePtr & res)80 bool SimplifyDataStructuresPass(const ResourcePtr &res) {
81   MS_EXCEPTION_IF_NULL(res);
82   FuncGraphPtr func_graph = res->func_graph();
83   MS_EXCEPTION_IF_NULL(func_graph);
84   bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
85   DoRenormalize(changed, func_graph, res);
86   return true;
87 }
88 
TransformTopGraphPass(const ResourcePtr & res)89 bool TransformTopGraphPass(const ResourcePtr &res) {
90   MS_EXCEPTION_IF_NULL(res);
91   if (res->func_graph() == nullptr) {
92     MS_LOG(EXCEPTION) << "Transform top graph error.";
93   }
94   FuncGraphPtr func_graph = res->func_graph();
95   if (opt::FuncGraphHasTupleInput(func_graph)) {
96     opt::GraphTupleParamTransform graph_trans;
97     func_graph = graph_trans(func_graph, res->manager());
98     res->set_func_graph(func_graph);
99     AbstractBasePtrList abs_spec_list;
100     auto &params = func_graph->parameters();
101     std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
102                    [](const AnfNodePtr &node) { return node->abstract(); });
103     res->set_args_spec(abs_spec_list);
104   }
105   return true;
106 }
107 
CleanAfterOptAPass(const ResourcePtr & res)108 bool CleanAfterOptAPass(const ResourcePtr &res) {
109   MS_EXCEPTION_IF_NULL(res);
110   FuncGraphPtr func_graph = res->func_graph();
111   MS_EXCEPTION_IF_NULL(func_graph);
112   bool changed = opt::CleanAfterOptA(func_graph, res->manager());
113   DoRenormalize(changed, func_graph, res);
114   return true;
115 }
116 
PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib & irpass,const ResourcePtr & res)117 FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
118   MS_EXCEPTION_IF_NULL(res);
119   MS_EXCEPTION_IF_NULL(res->func_graph());
120   opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
121     irpass.pynative_eliminate_,
122   });
123 
124   opt::OptPassConfig switch_simplify = opt::OptPassConfig({
125     irpass.switch_simplify_,
126   });
127 
128   opt::OptPassConfig inline_opt = opt::OptPassConfig({
129     irpass.inline_,
130   });
131 
132   OptPassGroupMap map(
133     {{"ad_eliminate", pynative_eliminate}, {"ad_inline", inline_opt}, {"ad_switch_simplify", switch_simplify}});
134 
135   auto prim_bprop_opt_step_1 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_1", res, map);
136   FuncGraphPtr func_graph = res->func_graph();
137   WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_1"))[&prim_bprop_opt_step_1, &func_graph]() {
138     func_graph = prim_bprop_opt_step_1->step(func_graph, true);
139   };
140   return func_graph;
141 }
142 
PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib & irpass,const ResourcePtr & res)143 FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
144   MS_EXCEPTION_IF_NULL(res);
145   MS_EXCEPTION_IF_NULL(res->func_graph());
146   opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
147     irpass.switch_simplify_,
148     irpass.reduce_eliminate_,
149     irpass.tile_eliminate_,
150     irpass.arithmetic_simplify_,
151   });
152 
153   opt::OptPassConfig inline_opt = opt::OptPassConfig({
154     irpass.inline_,
155   });
156 
157   auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
158     return ReAutoMonad(root);
159   };
160   OptPassGroupMap map({{"ad_renormalize", opt::OptPassConfig::Renormalize()},
161                        {"ad_inline", inline_opt},
162                        {"ad_special_op_simplify", special_op_simplify},
163                        {"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)}});
164 
165   auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", res, map);
166   FuncGraphPtr func_graph = res->func_graph();
167   WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_2"))[&prim_bprop_opt_step_2, &func_graph]() {
168     func_graph = prim_bprop_opt_step_2->step(func_graph, true);
169   };
170   return func_graph;
171 }
172 
BpropGraphFinalOptPass(const ResourcePtr & res)173 FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
174   MS_EXCEPTION_IF_NULL(res);
175   MS_EXCEPTION_IF_NULL(res->func_graph());
176   (void)TransformTopGraphPass(res);
177 
178   opt::irpass::OptimizeIRPassLib irpass;
179   opt::OptPassConfig bg_final_opt = opt::OptPassConfig({
180     irpass.inline_,
181     irpass.tuple_list_get_set_item_eliminator_,
182     irpass.tuple_list_get_item_eliminator_,
183     irpass.tuple_list_set_item_eliminator_,
184     irpass.depend_value_elim_,
185     irpass.reshape_eliminate_,
186     irpass.switch_simplify_,
187     irpass.addn_zero_filter_,
188   });
189   opt::OptPassConfig fill_zeros_like = opt::OptPassConfig{irpass.zero_like_fill_zero_};
190   OptPassGroupMap map({
191     {"ad_final_opt", bg_final_opt},
192     {"zeros_like", fill_zeros_like},
193   });
194 
195   if (pynative::PynativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
196     (void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
197     opt::OptPassConfig env_eliminate = opt::OptPassConfig({
198       irpass.incorporate_call_,
199       irpass.incorporate_call_switch_,
200       irpass.incorporate_getitem_set_,
201     });
202     (void)map.emplace_back(std::make_pair("env_eliminate", env_eliminate));
203   }
204 
205   auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", res, map);
206   FuncGraphPtr func_graph = res->func_graph();
207   WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
208     func_graph = bprop_graph_final_opt->step(func_graph, true);
209   };
210 
211   return func_graph;
212 }
213 
214 namespace {
ReAutoMonadWrapper(const FuncGraphPtr & root,const opt::OptimizerPtr &)215 bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }
216 
parallel_mode()217 bool parallel_mode() {
218 #if ((defined ENABLE_CPU) && (!defined _WIN32))
219   if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
220     return false;
221   }
222 #endif
223   std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
224   return (parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL);
225 }
226 
AddParallelRenormalize(OptPassGroupMap * map_a)227 void AddParallelRenormalize(OptPassGroupMap *map_a) {
228   if (parallel_mode()) {
229     auto parallel_end_opt =
230       find_if(map_a->begin(), map_a->end(), [](auto opt_pair) { return opt_pair.first == "grad"; });
231     if (parallel_end_opt != map_a->end()) {
232       (void)map_a->insert(parallel_end_opt, {"parallel_renormalize", opt::OptPassConfig::Renormalize()});
233     }
234   }
235 }
236 
// Builds the pass configuration used for the "a_1" group. The substitution order below is the
// order in which they are handed to OptPassConfig; inline_ and the state-cleanup substitutions
// are deliberately listed twice (before and after the simplification block).
opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
  auto a1_config = opt::OptPassConfig({
    // Switch handling
    irpass.switch_defer_inline_,
    irpass.switch_layer_defer_inline_,
    irpass.switch_simplify_,
    irpass.exchange_switch_depend_value_,
    irpass.float_depend_g_call_,

    // Safe inlining (first round)
    irpass.inline_,
    irpass.updatestate_useless_node_eliminater_,
    irpass.updatestate_pure_node_eliminater_,
    irpass.load_eliminater_,
    irpass.stopgrad_eliminater_,
    irpass.partial_eliminate_,
    irpass.replace_applicator_,

    // Tuple / list item handling
    irpass.tuple_list_get_item_eliminator_,
    irpass.tuple_list_get_item_const_eliminator_,
    irpass.tuple_list_set_item_eliminator_,
    irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_depend_reorder_,
    irpass.tuple_list_convert_item_index_to_positive_,

    // Env item handling
    irpass.env_get_item_eliminate_,
    irpass.env_get_item_add_eliminate_,
    irpass.env_get_set_item_eliminate_,
    irpass.env_get_item_depend_swap_,

    // Operator-specific eliminations
    irpass.cast_eliminate_,
    irpass.reshape_eliminate_,
    irpass.reduce_eliminate_,
    irpass.tile_eliminate_,
    irpass.transpose_eliminate_,
    irpass.minmaximum_grad_,
    irpass.get_make_ref_eliminate_,

    // Arithmetic simplifications
    irpass.arithmetic_simplify_,
    irpass.addn_zero_filter_,
    irpass.adjust_all_reduce_mul_add_,
    irpass.accumulaten_eliminater_,

    // Safe inlining (second round)
    irpass.inline_,
    irpass.updatestate_useless_node_eliminater_,
    irpass.updatestate_pure_node_eliminater_,
    irpass.load_eliminater_,
    irpass.stopgrad_eliminater_,
    irpass.sparse_tensor_eliminate_,
  });
  return a1_config;
}
290 
GetOptPassesA(const opt::irpass::OptimizeIRPassLib & irpass)291 OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
292   opt::OptPassConfig a_1 = GetOptPassA1(irpass);
293   opt::OptPassConfig a_2 = opt::OptPassConfig(
294     {
295       irpass.switch_simplify_,
296       irpass.specialize_transform_,
297       irpass.merge_addn_,
298       irpass.float_tuple_getitem_switch_,
299       irpass.float_env_getitem_switch_,
300       irpass.inline_,
301       irpass.incorporate_getitem_set_,
302       irpass.incorporate_call_,
303       irpass.incorporate_call_switch_,
304       irpass.incorporate_env_getitem_bypass_recursive_,
305       irpass.incorporate_env_getitem_switch_,
306       irpass.env_get_item_eliminate_,
307       irpass.depend_value_elim_,
308       irpass.all_reduce_const_elim_,
309     },
310     false, true);
311 
312   opt::OptPassConfig a_after_grad = opt::OptPassConfig({irpass.inline_without_move_});
313 
314   opt::OptPassConfig a_3 = opt::OptPassConfig(
315     {
316       irpass.arithmetic_simplify2_,
317       irpass.same_eliminate_,
318       irpass.check_bprop_eliminate_,
319       irpass.switch_layer_defer_inline_,
320       irpass.replace_applicator_,
321       irpass.mirror_mini_step_elim_,
322       irpass.virtual_add_elim_,
323       irpass.row_tensor_add_zeros_like_,
324       irpass.mini_step_allgather_replace_,
325       irpass.micro_step_allgather_replace_,
326     },
327     false, true);
328   opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});
329   opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
330   opt::OptPassConfig after_resolve_pass =
331     opt::OptPassConfig({irpass.get_make_ref_eliminate_, irpass.replace_old_param_});
332   opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
333   opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
334   opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
335 
336   // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
337   OptPassGroupMap map_a({{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
338                          {"a_1", a_1},
339                          {"updatestate_depend_eliminate", updatestate_depend_eliminate},
340                          {"updatestate_assign_eliminate", updatestate_assign_eliminate},
341                          {"updatestate_loads_eliminate", updatestate_loads_eliminate},
342                          {"parameter_eliminate", opt::OptPassConfig(opt::irpass::ParameterEliminator())},
343                          {"a_2", a_2},
344                          {"accelerated_algorithm", accelerated_algorithm},
345                          {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
346                          {"parallel", opt::OptPassConfig(parallel::StepParallel)},
347                          {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
348                          {"virtual_dataset", virtual_dataset},
349                          {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
350                          {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
351                          {"after_resolve", after_resolve_pass},
352                          {"a_after_grad", a_after_grad},
353                          {"renormalize", opt::OptPassConfig::Renormalize()},
354                          {"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},
355                          {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
356                          {"cse", opt::OptPassConfig(opt::CSEPass(false))},
357                          {"a_3", a_3}});
358   AddParallelRenormalize(&map_a);
359   return map_a;
360 }
361 
GetA1A2(const opt::irpass::OptimizeIRPassLib & irpass)362 OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
363   auto opt_a = GetOptPassesA(irpass);
364   constexpr auto a1_a2_len = 7;
365   OptPassGroupMap a1_a2(opt_a.begin(), opt_a.begin() + a1_a2_len);
366   return a1_a2;
367 }
368 
GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib & irpass)369 OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
370   opt::OptPassConfig c_1 = opt::OptPassConfig({
371     // Safe inlining,
372     irpass.inline_,
373     irpass.updatestate_useless_node_eliminater_,
374     irpass.updatestate_pure_node_eliminater_,
375     irpass.load_eliminater_,
376     irpass.switch_call_monad_eliminater_,
377     irpass.stopgrad_eliminater_,
378     irpass.partial_eliminate_,
379   });
380   opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
381   opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
382   opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
383 
384   OptPassGroupMap map_a({{"c_1", c_1},
385                          {"updatestate_depend_eliminate", updatestate_depend_eliminate},
386                          {"updatestate_assign_eliminate", updatestate_assign_eliminate},
387                          {"updatestate_loads_eliminate", updatestate_loads_eliminate},
388                          {"cse", opt::OptPassConfig(opt::CSEPass(false))},
389                          {"renormalize", opt::OptPassConfig::Renormalize()}});
390 
391   return map_a;
392 }
393 
GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib & irpass)394 OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
395   opt::OptPassConfig d_1 =
396     opt::OptPassConfig({irpass.call_graph_tuple_transform_, irpass.tuple_list_get_item_eliminator_,
397                         irpass.tuple_list_get_item_const_eliminator_, irpass.tuple_list_set_item_eliminator_,
398                         irpass.tuple_list_get_set_item_eliminator_, irpass.tuple_list_get_item_depend_reorder_,
399                         irpass.tuple_list_convert_item_index_to_positive_});
400 
401   OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
402 
403   return map_a;
404 }
405 
GetOptPassesB(const opt::irpass::OptimizeIRPassLib & irpass)406 OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
407   opt::OptPassConfig b_1 = opt::OptPassConfig({irpass.zero_like_fill_zero_,
408                                                irpass.tuple_list_get_item_eliminator_,
409                                                irpass.tuple_list_get_item_const_eliminator_,
410                                                irpass.tuple_list_set_item_eliminator_,
411                                                irpass.tuple_list_get_set_item_eliminator_,
412                                                irpass.tuple_list_get_item_depend_reorder_,
413                                                irpass.tuple_list_convert_item_index_to_positive_,
414                                                irpass.float_tuple_getitem_switch_,
415                                                irpass.reset_defer_inline_,
416                                                irpass.inline_,
417                                                irpass.updatestate_useless_node_eliminater_,
418                                                irpass.updatestate_pure_node_eliminater_,
419                                                irpass.load_eliminater_,
420                                                irpass.stopgrad_eliminater_,
421                                                irpass.special_op_eliminate_,
422                                                irpass.get_make_ref_eliminate_,
423                                                irpass.incorporate_env_getitem_,
424                                                irpass.incorporate_env_getitem_switch_,
425                                                irpass.env_get_item_eliminate_,
426                                                irpass.env_get_item_add_eliminate_,
427                                                irpass.env_get_set_item_eliminate_,
428                                                irpass.env_get_item_depend_swap_,
429                                                irpass.incorporate_env_getitem_switch_layer_,
430                                                irpass.value_based_eliminate_,
431                                                irpass.virtual_accu_grad_,
432                                                irpass.virtual_assign_add_,
433                                                irpass.mirror_micro_step_},
434                                               false, true);
435   opt::OptPassConfig b_2 = opt::OptPassConfig({
436     irpass.replace_refkey_by_param_,
437     irpass.make_ref_eliminate_,
438     irpass.get_ref_param_eliminate_,
439     irpass.row_tensor_eliminate_,
440   });
441   opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
442   opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
443   opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
444   OptPassGroupMap map({
445     {"b_1", b_1},
446     {"b_2", b_2},
447     {"updatestate_depend_eliminate", updatestate_depend_eliminate},
448     {"updatestate_assign_eliminate", updatestate_assign_eliminate},
449     {"updatestate_loads_eliminate", updatestate_loads_eliminate},
450     {"renormalize", opt::OptPassConfig::Renormalize()},
451     {"cse", opt::OptPassConfig(opt::CSEPass(false))},
452   });
453   return map;
454 }
455 
GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib & irpass)456 OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &irpass) {
457   opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
458     irpass.pynative_eliminate_,
459   });
460 
461   OptPassGroupMap map({
462     {"pynative_eliminate", pynative_eliminate},
463   });
464   return map;
465 }
466 
GetOptPassesC(const opt::irpass::OptimizeIRPassLib &)467 OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &) {
468   return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
469 }
470 
GetControlPhases(const opt::irpass::OptimizeIRPassLib &)471 OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &) {
472   opt::OptPassConfig control_group = opt::OptPassConfig(opt::irpass::ConvertSwitchReplacement());
473   OptPassGroupMap map({
474     {"control_group", control_group},
475     {"renormalize", opt::OptPassConfig::Renormalize()},
476   });
477   return map;
478 }
479 
GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib & irpass)480 OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
481   auto opt_a = GetOptPassesA(irpass);
482   auto a3 = opt_a[opt_a.size() - 1];
483   OptPassGroupMap map({
484     {"renormalize", opt::OptPassConfig::Renormalize()},
485     {"cse", opt::OptPassConfig(opt::CSEPass(false))},
486     {a3},
487   });
488   return map;
489 }
490 
GetInferenceOptPreparePhases()491 OptPassGroupMap GetInferenceOptPreparePhases() {
492   opt::irpass::InferenceOptPrepareLib irpass;
493   auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
494   opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}});
495   return prepare_map;
496 }
497 
GetPreparePhases(const opt::irpass::OptimizeIRPassLib & irpass)498 OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
499   opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_});
500   OptPassGroupMap map({{"prepare_group", prepare_group}});
501   return map;
502 }
503 
GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib & irpass)504 OptPassGroupMap GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
505   opt::OptPassConfig set_cell_output_no_recompute = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
506   OptPassGroupMap map({{"set_cell_output_no_recompute", set_cell_output_no_recompute}});
507   return map;
508 }
509 
GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &)510 OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
511   OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
512   return map;
513 }
514 
515 static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts = {};
516 
InitOpt(const ResourcePtr & res)517 void InitOpt(const ResourcePtr &res) {
518   if (g_pass_opts.size() == 0) {
519     opt::irpass::OptimizeIRPassLib irpass;
520     g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", res, GetA1A2(irpass));
521     g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
522     g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
523     g_pass_opts["opt_after_cconv"] =
524       Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
525     g_pass_opts["opt_trans_graph"] =
526       Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
527     g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
528     g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), true, true);
529     g_pass_opts["opt_grad_epilogue"] =
530       Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
531     g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
532     g_pass_opts["opt_before_recompute"] =
533       Optimizer::MakeOptimizer("opt_before_recompute", res, GetBeforeRecomputePass(irpass));
534     g_pass_opts["opt_after_recompute"] =
535       Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
536   }
537 }
538 }  // namespace
539 
ReclaimOptimizer()540 void ReclaimOptimizer() {
541   for (auto &opt : g_pass_opts) {
542     opt.second = nullptr;
543   }
544   g_pass_opts.clear();
545 }
546 
OptPassGroup(const ResourcePtr & res,const std::string & name)547 bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
548   MS_EXCEPTION_IF_NULL(res);
549   if (res->func_graph() == nullptr) {
550     MS_LOG(ERROR) << "Opt passes int64_t error";
551     return false;
552   }
553 
554   FuncGraphPtr func_graph = res->func_graph();
555   MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
556                 << func_graph->get_return()->DebugString(true);
557   InitOpt(res);
558   if (g_pass_opts.find(name) != g_pass_opts.end()) {
559     res->set_func_graph(g_pass_opts[name]->step(func_graph));
560   }
561   // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to
562   // res->args_spec_ yet. So if any later pass or action want to use that variable, it should be set here.
563   return true;
564 }
565 
OptPassA1A2(const ResourcePtr & res)566 bool OptPassA1A2(const ResourcePtr &res) { return OptPassGroup(res, "a1a2"); }
OptPassAGroup(const ResourcePtr & res)567 bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
OptPassBGroup(const ResourcePtr & res)568 bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
OptPassAfterCconvGroup(const ResourcePtr & res)569 bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
OptPassTransformGraphGroup(const ResourcePtr & res)570 bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
ControlGroup(const ResourcePtr & res)571 bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
PrepareGroup(const ResourcePtr & res)572 bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
OptBeforeRecomputeGroup(const ResourcePtr & res)573 bool OptBeforeRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_before_recompute"); }
OptAfterRecomputeGroup(const ResourcePtr & res)574 bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
575 
OptPassRNGroup(const ResourcePtr & res)576 bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
577 
OptPassGradEpilogueGroup(const ResourcePtr & res)578 bool OptPassGradEpilogueGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_grad_epilogue"); }
579 
AddRecomputationPass(const ResourcePtr & res)580 bool AddRecomputationPass(const ResourcePtr &res) {
581   MS_EXCEPTION_IF_NULL(res);
582   opt::InsertRecomputedNodes(res->func_graph());
583   return true;
584 }
585 
AddCacheEmbeddingPass(const ResourcePtr & res)586 bool AddCacheEmbeddingPass(const ResourcePtr &res) {
587   MS_EXCEPTION_IF_NULL(res);
588 #if ((defined ENABLE_CPU) && (!defined _WIN32))
589   if (ps::PSContext::instance()->is_ps_mode()) {
590     return true;
591   }
592 #endif
593   FuncGraphPtr func_graph = res->func_graph();
594   MS_EXCEPTION_IF_NULL(func_graph);
595 
596   parallel::AddCacheEmbedding(func_graph);
597   if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) {
598     auto params = func_graph->parameters();
599     AbstractBasePtrList args_spec_list;
600     std::for_each(params.begin(), params.end(),
601                   [&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); });
602     func_graph = pipeline::Renormalize(res, func_graph, args_spec_list);
603   }
604   return true;
605 }
606 
RemoveValueNodeDuplicationsPass(const ResourcePtr & res)607 bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
608   MS_EXCEPTION_IF_NULL(res);
609   if (res->func_graph() == nullptr) {
610     MS_LOG(EXCEPTION) << "Remove value node duplications error.";
611   }
612   auto manager = res->manager();
613   HashCache hash_cache;
614   HashValue hashes;
615   // Remove duplicated value nodes across all graphs in manager
616   auto node_user_map = manager->node_users();
617   for (auto &fg : manager->func_graphs()) {
618     auto value_nodes = fg->value_nodes();
619     for (const auto &value_pair : value_nodes) {
620       auto users = node_user_map[value_pair.first];
621       // For data parallel with some parameters redundant, the allreduce will share the same value node
622       // which will raise an error when do allreduce fusion, so the solution is to make the allreduce's value node
623       // not be removed, if we found the fusion tag.
624       if (users.size() == 1) {
625         auto cnode = users.front().first->cast<CNodePtr>();
626         if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce) && cnode->inputs().size() > 1 &&
627             cnode->input(1)->isa<ValueNode>()) {
628           auto allreduce_prim = GetCNodePrimitive(users.front().first);
629           auto attrs = allreduce_prim->attrs();
630           auto fusion_id = attrs.find(mindspore::parallel::FUSION);
631           if (fusion_id != attrs.end() && GetValue<int64_t>(fusion_id->second) > 0) {
632             continue;
633           }
634         }
635       }
636       TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
637     }
638   }
639   return true;
640 }
641 
CconvPass(const ResourcePtr & res)642 bool CconvPass(const ResourcePtr &res) {
643   MS_EXCEPTION_IF_NULL(res);
644   MS_EXCEPTION_IF_NULL(res->func_graph());
645   FuncGraphPtr func_graph = res->func_graph();
646   FuncGraphPtr new_fg = LiftingClone(func_graph);
647   res->set_func_graph(new_fg);
648   return true;
649 }
650 
PipelineSplitPass(const ResourcePtr & res)651 bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }
652 
ValidatePass(const ResourcePtr & res)653 bool ValidatePass(const ResourcePtr &res) {
654   MS_EXCEPTION_IF_NULL(res);
655   MS_EXCEPTION_IF_NULL(res->func_graph());
656   FuncGraphPtr func_graph = res->func_graph();
657   Validate(func_graph);
658   return true;
659 }
660 
InferenceOptPreparePass(const ResourcePtr & res)661 bool InferenceOptPreparePass(const ResourcePtr &res) {
662   FuncGraphPtr func_graph = res->func_graph();
663   MS_EXCEPTION_IF_NULL(func_graph);
664   auto prepare_map = GetInferenceOptPreparePhases();
665   auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
666   (void)infer_opt_prepare->step(func_graph, false);
667   return true;
668 }
669 
PynativeOptPass(const ResourcePtr & res)670 bool PynativeOptPass(const ResourcePtr &res) {
671   FuncGraphPtr func_graph = res->func_graph();
672   MS_EXCEPTION_IF_NULL(func_graph);
673   opt::irpass::OptimizeIRPassLib irpass;
674   auto pynative_opt = GetOptPassesPynativeElim(irpass);
675   auto pynative_opt_opt = opt::Optimizer::MakeOptimizer("pynative_opt", res, pynative_opt);
676   (void)pynative_opt_opt->step(func_graph, false);
677   return true;
678 }
679 
AutoMonadElimOptPass(const FuncGraphPtr & func_graph)680 bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
681   MS_EXCEPTION_IF_NULL(func_graph);
682   MS_EXCEPTION_IF_NULL(func_graph->manager());
683   auto res = std::make_shared<pipeline::Resource>();
684   res->set_func_graph(func_graph);
685   res->set_manager(func_graph->manager());
686 
687   // opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
688   opt::SubstitutionPtr updatestate_useless_node_eliminater =
689     opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateUselessNodeEliminater>(),
690                           "updatestate_useless_node_eliminater", prim::kPrimUpdateState);
691   opt::SubstitutionPtr updatestate_pure_node_eliminater =
692     opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestatePureNodeEliminater>(),
693                           "updatestate_pure_node_eliminater", prim::kPrimUpdateState);
694 
695   opt::OptPassConfig updatestate_eliminater = opt::OptPassConfig({
696     updatestate_useless_node_eliminater,
697     updatestate_pure_node_eliminater,
698   });
699   opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
700   opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
701   opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
702   opt::OptPassGroupMap elim_map({
703     {"updatestate_eliminater", updatestate_eliminater},
704     {"updatestate_depend_eliminate", updatestate_depend_eliminate},
705     {"updatestate_assign_eliminate", updatestate_assign_eliminate},
706     {"updatestate_loads_eliminate", updatestate_loads_eliminate},
707     {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
708   });
709 
710   auto auto_monad_elim_opt = opt::Optimizer::MakeOptimizer("auto_monad_elim", res, elim_map);
711   (void)auto_monad_elim_opt->step(func_graph, false);
712   return true;
713 }
714 
715 std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
716                                    {"opt_before_recompute", OptBeforeRecomputeGroup},
717                                    {"opt_a", OptPassAGroup},
718                                    {"clean_after_opta", CleanAfterOptAPass},
719                                    {"opt_b", OptPassBGroup},
720                                    {"cconv", CconvPass},
721                                    {"opt_after_cconv", OptPassAfterCconvGroup},
722                                    {"remove_dup_value", RemoveValueNodeDuplicationsPass},
723                                    {"tuple_transform", OptPassTransformGraphGroup},
724                                    {"add_cache_embedding", AddCacheEmbeddingPass},
725                                    {"add_recomputation", AddRecomputationPass},
726                                    {"cse_after_recomputation", OptAfterRecomputeGroup}};
727 
728 std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
729                                    {"opt_a", OptPassAGroup},
730                                    {"clean_after_opta", CleanAfterOptAPass},
731                                    {"opt_b", OptPassBGroup},
732                                    {"opt_control", ControlGroup},
733                                    {"opt_prepare", PrepareGroup},
734                                    {"cconv", CconvPass}};
735 
736 std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
737                                          {"opt_b", OptPassBGroup},
738                                          {"cconv", CconvPass},
739                                          {"transform_top", TransformTopGraphPass},
740                                          {"transform_graph", OptPassTransformGraphGroup}};
741 
742 std::vector<PassItem> kInlinePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"a1a2", OptPassA1A2}};
743 }  // namespace pipeline
744 }  // namespace mindspore
745