1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/pass.h"
18
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
24
25 #include "ir/func_graph_cloner.h"
26 #include "pipeline/jit/parse/parse_base.h"
27 #include "pipeline/jit/resource.h"
28 #include "pipeline/jit/validator.h"
29 #include "pipeline/jit/remove_value_node_dup.h"
30 #include "frontend/optimizer/opt.h"
31 #include "frontend/optimizer/optimizer.h"
32 #include "frontend/optimizer/cse_pass.h"
33 #include "frontend/optimizer/clean.h"
34 #include "frontend/optimizer/irpass.h"
35 #include "frontend/optimizer/graph_transform.h"
36 #include "frontend/optimizer/auto_monad_eliminate.h"
37 #include "frontend/parallel/context.h"
38 #include "frontend/parallel/step_parallel.h"
39 #include "frontend/parallel/step_auto_parallel.h"
40 #include "frontend/parallel/cache_embedding/cache_embedding.h"
41 #include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
42 #include "frontend/optimizer/recompute.h"
43 #include "utils/log_adapter.h"
44 #include "pipeline/jit/pipeline_split.h"
45 #include "pipeline/pynative/pynative_execute.h"
46 #include "pipeline/jit/static_analysis/auto_monad.h"
47 #include "frontend/optimizer/irpass/branch_culling.h"
48 #include "frontend/optimizer/irpass/gradient_eliminate.h"
49 #include "frontend/optimizer/irpass/parameter_eliminate.h"
50 #include "frontend/optimizer/irpass/updatestate_eliminate.h"
51 #if ((defined ENABLE_CPU) && (!defined _WIN32))
52 #include "ps/util.h"
53 #include "ps/ps_context.h"
54 #endif
55
56 namespace mindspore {
57 namespace pipeline {
58 using OptPassGroupMap = opt::OptPassGroupMap;
59 using Optimizer = opt::Optimizer;
60 using CompileGraphs = compile::CompileGraphs;
61 using abstract::AnalysisResult;
62 using mindspore::abstract::AnalysisContextPtr;
63 using mindspore::validator::Validate;
64 namespace {
DoRenormalize(const bool & changed,const FuncGraphPtr & func_graph,const ResourcePtr & res)65 void DoRenormalize(const bool &changed, const FuncGraphPtr &func_graph, const ResourcePtr &res) {
66 MS_EXCEPTION_IF_NULL(func_graph);
67 MS_EXCEPTION_IF_NULL(res);
68 abstract::AbstractBasePtrList args_spec;
69 auto parameters = func_graph->parameters();
70 (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
71 [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
72 if (changed) {
73 FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
74 res->set_func_graph(new_fg);
75 }
76 res->set_args_spec(args_spec);
77 }
78 } // namespace
79
SimplifyDataStructuresPass(const ResourcePtr & res)80 bool SimplifyDataStructuresPass(const ResourcePtr &res) {
81 MS_EXCEPTION_IF_NULL(res);
82 FuncGraphPtr func_graph = res->func_graph();
83 MS_EXCEPTION_IF_NULL(func_graph);
84 bool changed = opt::SimplifyDataStructures(func_graph, res->manager());
85 DoRenormalize(changed, func_graph, res);
86 return true;
87 }
88
TransformTopGraphPass(const ResourcePtr & res)89 bool TransformTopGraphPass(const ResourcePtr &res) {
90 MS_EXCEPTION_IF_NULL(res);
91 if (res->func_graph() == nullptr) {
92 MS_LOG(EXCEPTION) << "Transform top graph error.";
93 }
94 FuncGraphPtr func_graph = res->func_graph();
95 if (opt::FuncGraphHasTupleInput(func_graph)) {
96 opt::GraphTupleParamTransform graph_trans;
97 func_graph = graph_trans(func_graph, res->manager());
98 res->set_func_graph(func_graph);
99 AbstractBasePtrList abs_spec_list;
100 auto ¶ms = func_graph->parameters();
101 std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
102 [](const AnfNodePtr &node) { return node->abstract(); });
103 res->set_args_spec(abs_spec_list);
104 }
105 return true;
106 }
107
CleanAfterOptAPass(const ResourcePtr & res)108 bool CleanAfterOptAPass(const ResourcePtr &res) {
109 MS_EXCEPTION_IF_NULL(res);
110 FuncGraphPtr func_graph = res->func_graph();
111 MS_EXCEPTION_IF_NULL(func_graph);
112 bool changed = opt::CleanAfterOptA(func_graph, res->manager());
113 DoRenormalize(changed, func_graph, res);
114 return true;
115 }
116
// First optimization stage for primitive bprop graphs.
// Builds an optimizer with the groups "ad_eliminate" (pynative_eliminate_), "ad_inline" (inline_) and
// "ad_switch_simplify" (switch_simplify_), runs it on res->func_graph() and returns the resulting graph.
// This function does not itself call res->set_func_graph(); the caller decides what to do with the result.
FuncGraphPtr PrimBpOptPassStep1(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  MS_EXCEPTION_IF_NULL(res->func_graph());
  opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
    irpass.pynative_eliminate_,
  });

  opt::OptPassConfig switch_simplify = opt::OptPassConfig({
    irpass.switch_simplify_,
  });

  opt::OptPassConfig inline_opt = opt::OptPassConfig({
    irpass.inline_,
  });

  // NOTE(review): groups are presumably applied in the listed order; keep it unless the optimizer says otherwise.
  OptPassGroupMap map(
    {{"ad_eliminate", pynative_eliminate}, {"ad_inline", inline_opt}, {"ad_switch_simplify", switch_simplify}});

  auto prim_bprop_opt_step_1 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_1", res, map);
  FuncGraphPtr func_graph = res->func_graph();
  // Profiled under the step name "prim_bprop_opt_step_1"; the lambda writes the optimized graph into func_graph.
  WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_1"))[&prim_bprop_opt_step_1, &func_graph]() {
    func_graph = prim_bprop_opt_step_1->step(func_graph, true);
  };
  return func_graph;
}
142
// Second optimization stage for primitive bprop graphs.
// Groups: "ad_renormalize" (renormalize), "ad_inline" (inline_), "ad_special_op_simplify"
// (switch/reduce/tile/arithmetic simplifications) and "auto_monad_grad" (ReAutoMonad on the root graph).
// Runs the optimizer on res->func_graph() and returns the resulting graph without storing it into `res`.
FuncGraphPtr PrimBpOptPassStep2(const opt::irpass::OptimizeIRPassLib &irpass, const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  MS_EXCEPTION_IF_NULL(res->func_graph());
  opt::OptPassConfig special_op_simplify = opt::OptPassConfig({
    irpass.switch_simplify_,
    irpass.reduce_eliminate_,
    irpass.tile_eliminate_,
    irpass.arithmetic_simplify_,
  });

  opt::OptPassConfig inline_opt = opt::OptPassConfig({
    irpass.inline_,
  });

  // Same behaviour as ReAutoMonadWrapper(); a local lambda is used presumably because that helper is defined
  // later in this file and is not visible here.
  auto re_auto_monadwrapper = [](const FuncGraphPtr &root, const opt::OptimizerPtr &) -> bool {
    return ReAutoMonad(root);
  };
  OptPassGroupMap map({{"ad_renormalize", opt::OptPassConfig::Renormalize()},
                       {"ad_inline", inline_opt},
                       {"ad_special_op_simplify", special_op_simplify},
                       {"auto_monad_grad", opt::OptPassConfig(re_auto_monadwrapper)}});

  auto prim_bprop_opt_step_2 = opt::Optimizer::MakeOptimizer("prim_bprop_opt_step_2", res, map);
  FuncGraphPtr func_graph = res->func_graph();
  // Profiled under the step name "prim_bprop_opt_step_2"; the lambda writes the optimized graph into func_graph.
  WITH(MsProfile::GetProfile()->Step("prim_bprop_opt_step_2"))[&prim_bprop_opt_step_2, &func_graph]() {
    func_graph = prim_bprop_opt_step_2->step(func_graph, true);
  };
  return func_graph;
}
172
// Final optimization of a bprop graph.
// First applies TransformTopGraphPass (tuple-input rewrite of the top graph, if any). It then runs the groups
// "ad_final_opt" and "zeros_like". When the PyNative grad executor reports need_renormalize(), it additionally
// runs "renormalize" and "env_eliminate" (incorporate_call_/incorporate_call_switch_/incorporate_getitem_set_).
// Returns the optimized graph; this function does not call res->set_func_graph() on the result.
FuncGraphPtr BpropGraphFinalOptPass(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  MS_EXCEPTION_IF_NULL(res->func_graph());
  // TransformTopGraphPass always returns true, so its result is ignored.
  (void)TransformTopGraphPass(res);

  opt::irpass::OptimizeIRPassLib irpass;
  opt::OptPassConfig bg_final_opt = opt::OptPassConfig({
    irpass.inline_,
    irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_eliminator_,
    irpass.tuple_list_set_item_eliminator_,
    irpass.depend_value_elim_,
    irpass.reshape_eliminate_,
    irpass.switch_simplify_,
    irpass.addn_zero_filter_,
  });
  opt::OptPassConfig fill_zeros_like = opt::OptPassConfig{irpass.zero_like_fill_zero_};
  OptPassGroupMap map({
    {"ad_final_opt", bg_final_opt},
    {"zeros_like", fill_zeros_like},
  });

  // Extra groups are appended after "zeros_like" only when the grad executor asks for renormalization.
  if (pynative::PynativeExecutor::GetInstance()->grad_executor()->need_renormalize()) {
    (void)map.emplace_back(std::make_pair("renormalize", opt::OptPassConfig::Renormalize()));
    opt::OptPassConfig env_eliminate = opt::OptPassConfig({
      irpass.incorporate_call_,
      irpass.incorporate_call_switch_,
      irpass.incorporate_getitem_set_,
    });
    (void)map.emplace_back(std::make_pair("env_eliminate", env_eliminate));
  }

  auto bprop_graph_final_opt = opt::Optimizer::MakeOptimizer("bprop_graph_final_opt", res, map);
  // res->func_graph() is re-read here because TransformTopGraphPass may have replaced it.
  FuncGraphPtr func_graph = res->func_graph();
  WITH(MsProfile::GetProfile()->Step("bprop_graph_final_opt"))[&bprop_graph_final_opt, &func_graph]() {
    func_graph = bprop_graph_final_opt->step(func_graph, true);
  };

  return func_graph;
}
213
214 namespace {
ReAutoMonadWrapper(const FuncGraphPtr & root,const opt::OptimizerPtr &)215 bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); }
216
parallel_mode()217 bool parallel_mode() {
218 #if ((defined ENABLE_CPU) && (!defined _WIN32))
219 if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
220 return false;
221 }
222 #endif
223 std::string parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
224 return (parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL);
225 }
226
AddParallelRenormalize(OptPassGroupMap * map_a)227 void AddParallelRenormalize(OptPassGroupMap *map_a) {
228 if (parallel_mode()) {
229 auto parallel_end_opt =
230 find_if(map_a->begin(), map_a->end(), [](auto opt_pair) { return opt_pair.first == "grad"; });
231 if (parallel_end_opt != map_a->end()) {
232 (void)map_a->insert(parallel_end_opt, {"parallel_renormalize", opt::OptPassConfig::Renormalize()});
233 }
234 }
235 }
236
// Substitution list used as pass group "a_1" in GetOptPassesA().
// NOTE(review): inline_, updatestate_useless_node_eliminater_, updatestate_pure_node_eliminater_, load_eliminater_
// and stopgrad_eliminater_ appear in both "Safe inlining" sections. This is presumably deliberate; confirm
// against the optimizer's substitution semantics before de-duplicating.
opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
  return opt::OptPassConfig({
    irpass.switch_defer_inline_,
    irpass.switch_layer_defer_inline_,
    irpass.switch_simplify_,
    irpass.exchange_switch_depend_value_,
    irpass.float_depend_g_call_,

    // Safe inlining
    irpass.inline_,
    irpass.updatestate_useless_node_eliminater_,
    irpass.updatestate_pure_node_eliminater_,
    irpass.load_eliminater_,
    irpass.stopgrad_eliminater_,
    irpass.partial_eliminate_,
    irpass.replace_applicator_,

    // Miscellaneous
    irpass.tuple_list_get_item_eliminator_,
    irpass.tuple_list_get_item_const_eliminator_,
    irpass.tuple_list_set_item_eliminator_,
    irpass.tuple_list_get_set_item_eliminator_,
    irpass.tuple_list_get_item_depend_reorder_,
    irpass.tuple_list_convert_item_index_to_positive_,

    irpass.env_get_item_eliminate_,
    irpass.env_get_item_add_eliminate_,
    irpass.env_get_set_item_eliminate_,
    irpass.env_get_item_depend_swap_,

    irpass.cast_eliminate_,
    irpass.reshape_eliminate_,
    irpass.reduce_eliminate_,
    irpass.tile_eliminate_,
    irpass.transpose_eliminate_,
    irpass.minmaximum_grad_,
    irpass.get_make_ref_eliminate_,

    // Arithmetic simplifications
    irpass.arithmetic_simplify_,
    irpass.addn_zero_filter_,
    irpass.adjust_all_reduce_mul_add_,
    irpass.accumulaten_eliminater_,

    // Safe inlining
    irpass.inline_,
    irpass.updatestate_useless_node_eliminater_,
    irpass.updatestate_pure_node_eliminater_,
    irpass.load_eliminater_,
    irpass.stopgrad_eliminater_,
    irpass.sparse_tensor_eliminate_,
  });
}
290
// Builds the ordered pass-group list for the "opt_a" optimizer.
// Other builders depend on its layout by position:
//   - GetA1A2() takes the first seven groups ("switch_simplify" .. "a_2").
//   - GetOptPynativeGradEpiloguePhases() takes the last group ("a_3").
// In auto/semi-auto parallel mode, AddParallelRenormalize() also inserts "parallel_renormalize" right before
// "grad". That is after the first seven groups and before the last one, so neither slice is affected.
// NOTE(review): the trailing (false, true) arguments of a_2 / a_3 are OptPassConfig flags declared elsewhere.
OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig a_1 = GetOptPassA1(irpass);
  opt::OptPassConfig a_2 = opt::OptPassConfig(
    {
      irpass.switch_simplify_,
      irpass.specialize_transform_,
      irpass.merge_addn_,
      irpass.float_tuple_getitem_switch_,
      irpass.float_env_getitem_switch_,
      irpass.inline_,
      irpass.incorporate_getitem_set_,
      irpass.incorporate_call_,
      irpass.incorporate_call_switch_,
      irpass.incorporate_env_getitem_bypass_recursive_,
      irpass.incorporate_env_getitem_switch_,
      irpass.env_get_item_eliminate_,
      irpass.depend_value_elim_,
      irpass.all_reduce_const_elim_,
    },
    false, true);

  opt::OptPassConfig a_after_grad = opt::OptPassConfig({irpass.inline_without_move_});

  opt::OptPassConfig a_3 = opt::OptPassConfig(
    {
      irpass.arithmetic_simplify2_,
      irpass.same_eliminate_,
      irpass.check_bprop_eliminate_,
      irpass.switch_layer_defer_inline_,
      irpass.replace_applicator_,
      irpass.mirror_mini_step_elim_,
      irpass.virtual_add_elim_,
      irpass.row_tensor_add_zeros_like_,
      irpass.mini_step_allgather_replace_,
      irpass.micro_step_allgather_replace_,
    },
    false, true);
  opt::OptPassConfig accelerated_algorithm = opt::OptPassConfig({irpass.less_batch_normalization_});
  opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
  opt::OptPassConfig after_resolve_pass =
    opt::OptPassConfig({irpass.get_make_ref_eliminate_, irpass.replace_old_param_});
  opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
  opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
  opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());

  // Before adjusting map_a, check GetA1A2() and GetOptPynativeGradEpiloguePhases().
  OptPassGroupMap map_a({{"switch_simplify", opt::OptPassConfig({irpass.switch_simplify_})},
                         {"a_1", a_1},
                         {"updatestate_depend_eliminate", updatestate_depend_eliminate},
                         {"updatestate_assign_eliminate", updatestate_assign_eliminate},
                         {"updatestate_loads_eliminate", updatestate_loads_eliminate},
                         {"parameter_eliminate", opt::OptPassConfig(opt::irpass::ParameterEliminator())},
                         {"a_2", a_2},
                         {"accelerated_algorithm", accelerated_algorithm},
                         {"auto_parallel", opt::OptPassConfig(parallel::StepAutoParallel)},
                         {"parallel", opt::OptPassConfig(parallel::StepParallel)},
                         {"allreduce_fusion", opt::OptPassConfig(parallel::StepAllreduceFusion)},
                         {"virtual_dataset", virtual_dataset},
                         {"virtual_output", opt::OptPassConfig({irpass.virtual_output_eliminate_})},
                         {"grad", opt::OptPassConfig(opt::irpass::ExpandJPrim())},
                         {"after_resolve", after_resolve_pass},
                         {"a_after_grad", a_after_grad},
                         {"renormalize", opt::OptPassConfig::Renormalize()},
                         {"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)},
                         {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
                         {"cse", opt::OptPassConfig(opt::CSEPass(false))},
                         {"a_3", a_3}});
  // Parallel mode only: insert "parallel_renormalize" before "grad".
  AddParallelRenormalize(&map_a);
  return map_a;
}
361
GetA1A2(const opt::irpass::OptimizeIRPassLib & irpass)362 OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
363 auto opt_a = GetOptPassesA(irpass);
364 constexpr auto a1_a2_len = 7;
365 OptPassGroupMap a1_a2(opt_a.begin(), opt_a.begin() + a1_a2_len);
366 return a1_a2;
367 }
368
GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib & irpass)369 OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
370 opt::OptPassConfig c_1 = opt::OptPassConfig({
371 // Safe inlining,
372 irpass.inline_,
373 irpass.updatestate_useless_node_eliminater_,
374 irpass.updatestate_pure_node_eliminater_,
375 irpass.load_eliminater_,
376 irpass.switch_call_monad_eliminater_,
377 irpass.stopgrad_eliminater_,
378 irpass.partial_eliminate_,
379 });
380 opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
381 opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
382 opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
383
384 OptPassGroupMap map_a({{"c_1", c_1},
385 {"updatestate_depend_eliminate", updatestate_depend_eliminate},
386 {"updatestate_assign_eliminate", updatestate_assign_eliminate},
387 {"updatestate_loads_eliminate", updatestate_loads_eliminate},
388 {"cse", opt::OptPassConfig(opt::CSEPass(false))},
389 {"renormalize", opt::OptPassConfig::Renormalize()}});
390
391 return map_a;
392 }
393
GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib & irpass)394 OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
395 opt::OptPassConfig d_1 =
396 opt::OptPassConfig({irpass.call_graph_tuple_transform_, irpass.tuple_list_get_item_eliminator_,
397 irpass.tuple_list_get_item_const_eliminator_, irpass.tuple_list_set_item_eliminator_,
398 irpass.tuple_list_get_set_item_eliminator_, irpass.tuple_list_get_item_depend_reorder_,
399 irpass.tuple_list_convert_item_index_to_positive_});
400
401 OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
402
403 return map_a;
404 }
405
// Ordered pass groups for the "opt_b" optimizer:
//   "b_1" (zeros_like filling, tuple/list, inlining, env-item and virtual-accumulation substitutions),
//   "b_2" (ref-key / make_ref / row-tensor eliminations), the three UpdateState eliminaters,
//   "renormalize" and finally "cse".
// NOTE(review): the (false, true) arguments given to b_1 are OptPassConfig flags declared elsewhere.
OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
  opt::OptPassConfig b_1 = opt::OptPassConfig({irpass.zero_like_fill_zero_,
                                               irpass.tuple_list_get_item_eliminator_,
                                               irpass.tuple_list_get_item_const_eliminator_,
                                               irpass.tuple_list_set_item_eliminator_,
                                               irpass.tuple_list_get_set_item_eliminator_,
                                               irpass.tuple_list_get_item_depend_reorder_,
                                               irpass.tuple_list_convert_item_index_to_positive_,
                                               irpass.float_tuple_getitem_switch_,
                                               irpass.reset_defer_inline_,
                                               irpass.inline_,
                                               irpass.updatestate_useless_node_eliminater_,
                                               irpass.updatestate_pure_node_eliminater_,
                                               irpass.load_eliminater_,
                                               irpass.stopgrad_eliminater_,
                                               irpass.special_op_eliminate_,
                                               irpass.get_make_ref_eliminate_,
                                               irpass.incorporate_env_getitem_,
                                               irpass.incorporate_env_getitem_switch_,
                                               irpass.env_get_item_eliminate_,
                                               irpass.env_get_item_add_eliminate_,
                                               irpass.env_get_set_item_eliminate_,
                                               irpass.env_get_item_depend_swap_,
                                               irpass.incorporate_env_getitem_switch_layer_,
                                               irpass.value_based_eliminate_,
                                               irpass.virtual_accu_grad_,
                                               irpass.virtual_assign_add_,
                                               irpass.mirror_micro_step_},
                                              false, true);
  opt::OptPassConfig b_2 = opt::OptPassConfig({
    irpass.replace_refkey_by_param_,
    irpass.make_ref_eliminate_,
    irpass.get_ref_param_eliminate_,
    irpass.row_tensor_eliminate_,
  });
  opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
  opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
  opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
  OptPassGroupMap map({
    {"b_1", b_1},
    {"b_2", b_2},
    {"updatestate_depend_eliminate", updatestate_depend_eliminate},
    {"updatestate_assign_eliminate", updatestate_assign_eliminate},
    {"updatestate_loads_eliminate", updatestate_loads_eliminate},
    {"renormalize", opt::OptPassConfig::Renormalize()},
    {"cse", opt::OptPassConfig(opt::CSEPass(false))},
  });
  return map;
}
455
GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib & irpass)456 OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &irpass) {
457 opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
458 irpass.pynative_eliminate_,
459 });
460
461 OptPassGroupMap map({
462 {"pynative_eliminate", pynative_eliminate},
463 });
464 return map;
465 }
466
GetOptPassesC(const opt::irpass::OptimizeIRPassLib &)467 OptPassGroupMap GetOptPassesC(const opt::irpass::OptimizeIRPassLib &) {
468 return OptPassGroupMap({{"renormalize", opt::OptPassConfig::Renormalize()}});
469 }
470
GetControlPhases(const opt::irpass::OptimizeIRPassLib &)471 OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &) {
472 opt::OptPassConfig control_group = opt::OptPassConfig(opt::irpass::ConvertSwitchReplacement());
473 OptPassGroupMap map({
474 {"control_group", control_group},
475 {"renormalize", opt::OptPassConfig::Renormalize()},
476 });
477 return map;
478 }
479
GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib & irpass)480 OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
481 auto opt_a = GetOptPassesA(irpass);
482 auto a3 = opt_a[opt_a.size() - 1];
483 OptPassGroupMap map({
484 {"renormalize", opt::OptPassConfig::Renormalize()},
485 {"cse", opt::OptPassConfig(opt::CSEPass(false))},
486 {a3},
487 });
488 return map;
489 }
490
GetInferenceOptPreparePhases()491 OptPassGroupMap GetInferenceOptPreparePhases() {
492 opt::irpass::InferenceOptPrepareLib irpass;
493 auto grad_var_prepare = opt::OptPassConfig({irpass.grad_var_prepare_});
494 opt::OptPassGroupMap prepare_map({{"inference_opt_prep", grad_var_prepare}});
495 return prepare_map;
496 }
497
GetPreparePhases(const opt::irpass::OptimizeIRPassLib & irpass)498 OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
499 opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_});
500 OptPassGroupMap map({{"prepare_group", prepare_group}});
501 return map;
502 }
503
GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib & irpass)504 OptPassGroupMap GetBeforeRecomputePass(const opt::irpass::OptimizeIRPassLib &irpass) {
505 opt::OptPassConfig set_cell_output_no_recompute = opt::OptPassConfig({irpass.set_cell_output_no_recompute_});
506 OptPassGroupMap map({{"set_cell_output_no_recompute", set_cell_output_no_recompute}});
507 return map;
508 }
509
GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &)510 OptPassGroupMap GetAfterRecomputePass(const opt::irpass::OptimizeIRPassLib &) {
511 OptPassGroupMap map({{"cse", opt::OptPassConfig(opt::CSEPass(false))}});
512 return map;
513 }
514
515 static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts = {};
516
InitOpt(const ResourcePtr & res)517 void InitOpt(const ResourcePtr &res) {
518 if (g_pass_opts.size() == 0) {
519 opt::irpass::OptimizeIRPassLib irpass;
520 g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", res, GetA1A2(irpass));
521 g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
522 g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
523 g_pass_opts["opt_after_cconv"] =
524 Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
525 g_pass_opts["opt_trans_graph"] =
526 Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
527 g_pass_opts["renormal"] = Optimizer::MakeOptimizer("renormal", res, GetOptPassesC(irpass));
528 g_pass_opts["opt_control"] = Optimizer::MakeOptimizer("opt_control", res, GetControlPhases(irpass), true, true);
529 g_pass_opts["opt_grad_epilogue"] =
530 Optimizer::MakeOptimizer("opt_grad_epilogue", res, GetOptPynativeGradEpiloguePhases(irpass), true, false);
531 g_pass_opts["opt_prepare"] = Optimizer::MakeOptimizer("opt_prepare", res, GetPreparePhases(irpass));
532 g_pass_opts["opt_before_recompute"] =
533 Optimizer::MakeOptimizer("opt_before_recompute", res, GetBeforeRecomputePass(irpass));
534 g_pass_opts["opt_after_recompute"] =
535 Optimizer::MakeOptimizer("opt_after_recompute", res, GetAfterRecomputePass(irpass));
536 }
537 }
538 } // namespace
539
ReclaimOptimizer()540 void ReclaimOptimizer() {
541 for (auto &opt : g_pass_opts) {
542 opt.second = nullptr;
543 }
544 g_pass_opts.clear();
545 }
546
OptPassGroup(const ResourcePtr & res,const std::string & name)547 bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
548 MS_EXCEPTION_IF_NULL(res);
549 if (res->func_graph() == nullptr) {
550 MS_LOG(ERROR) << "Opt passes int64_t error";
551 return false;
552 }
553
554 FuncGraphPtr func_graph = res->func_graph();
555 MS_LOG(DEBUG) << "Start " << name << " func graph:" << func_graph->ToString() << ", "
556 << func_graph->get_return()->DebugString(true);
557 InitOpt(res);
558 if (g_pass_opts.find(name) != g_pass_opts.end()) {
559 res->set_func_graph(g_pass_opts[name]->step(func_graph));
560 }
561 // Note: StepParallel may modify the AbstractValue of the parameters of func_graph, but they are not updated to
562 // res->args_spec_ yet. So if any later pass or action want to use that variable, it should be set here.
563 return true;
564 }
565
OptPassA1A2(const ResourcePtr & res)566 bool OptPassA1A2(const ResourcePtr &res) { return OptPassGroup(res, "a1a2"); }
OptPassAGroup(const ResourcePtr & res)567 bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
OptPassBGroup(const ResourcePtr & res)568 bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
OptPassAfterCconvGroup(const ResourcePtr & res)569 bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
OptPassTransformGraphGroup(const ResourcePtr & res)570 bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
ControlGroup(const ResourcePtr & res)571 bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
PrepareGroup(const ResourcePtr & res)572 bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); }
OptBeforeRecomputeGroup(const ResourcePtr & res)573 bool OptBeforeRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_before_recompute"); }
OptAfterRecomputeGroup(const ResourcePtr & res)574 bool OptAfterRecomputeGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_recompute"); }
575
OptPassRNGroup(const ResourcePtr & res)576 bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal"); }
577
OptPassGradEpilogueGroup(const ResourcePtr & res)578 bool OptPassGradEpilogueGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_grad_epilogue"); }
579
AddRecomputationPass(const ResourcePtr & res)580 bool AddRecomputationPass(const ResourcePtr &res) {
581 MS_EXCEPTION_IF_NULL(res);
582 opt::InsertRecomputedNodes(res->func_graph());
583 return true;
584 }
585
AddCacheEmbeddingPass(const ResourcePtr & res)586 bool AddCacheEmbeddingPass(const ResourcePtr &res) {
587 MS_EXCEPTION_IF_NULL(res);
588 #if ((defined ENABLE_CPU) && (!defined _WIN32))
589 if (ps::PSContext::instance()->is_ps_mode()) {
590 return true;
591 }
592 #endif
593 FuncGraphPtr func_graph = res->func_graph();
594 MS_EXCEPTION_IF_NULL(func_graph);
595
596 parallel::AddCacheEmbedding(func_graph);
597 if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) {
598 auto params = func_graph->parameters();
599 AbstractBasePtrList args_spec_list;
600 std::for_each(params.begin(), params.end(),
601 [&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); });
602 func_graph = pipeline::Renormalize(res, func_graph, args_spec_list);
603 }
604 return true;
605 }
606
RemoveValueNodeDuplicationsPass(const ResourcePtr & res)607 bool RemoveValueNodeDuplicationsPass(const ResourcePtr &res) {
608 MS_EXCEPTION_IF_NULL(res);
609 if (res->func_graph() == nullptr) {
610 MS_LOG(EXCEPTION) << "Remove value node duplications error.";
611 }
612 auto manager = res->manager();
613 HashCache hash_cache;
614 HashValue hashes;
615 // Remove duplicated value nodes across all graphs in manager
616 auto node_user_map = manager->node_users();
617 for (auto &fg : manager->func_graphs()) {
618 auto value_nodes = fg->value_nodes();
619 for (const auto &value_pair : value_nodes) {
620 auto users = node_user_map[value_pair.first];
621 // For data parallel with some parameters redundant, the allreduce will share the same value node
622 // which will raise an error when do allreduce fusion, so the solution is to make the allreduce's value node
623 // not be removed, if we found the fusion tag.
624 if (users.size() == 1) {
625 auto cnode = users.front().first->cast<CNodePtr>();
626 if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce) && cnode->inputs().size() > 1 &&
627 cnode->input(1)->isa<ValueNode>()) {
628 auto allreduce_prim = GetCNodePrimitive(users.front().first);
629 auto attrs = allreduce_prim->attrs();
630 auto fusion_id = attrs.find(mindspore::parallel::FUSION);
631 if (fusion_id != attrs.end() && GetValue<int64_t>(fusion_id->second) > 0) {
632 continue;
633 }
634 }
635 }
636 TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
637 }
638 }
639 return true;
640 }
641
CconvPass(const ResourcePtr & res)642 bool CconvPass(const ResourcePtr &res) {
643 MS_EXCEPTION_IF_NULL(res);
644 MS_EXCEPTION_IF_NULL(res->func_graph());
645 FuncGraphPtr func_graph = res->func_graph();
646 FuncGraphPtr new_fg = LiftingClone(func_graph);
647 res->set_func_graph(new_fg);
648 return true;
649 }
650
PipelineSplitPass(const ResourcePtr & res)651 bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }
652
ValidatePass(const ResourcePtr & res)653 bool ValidatePass(const ResourcePtr &res) {
654 MS_EXCEPTION_IF_NULL(res);
655 MS_EXCEPTION_IF_NULL(res->func_graph());
656 FuncGraphPtr func_graph = res->func_graph();
657 Validate(func_graph);
658 return true;
659 }
660
InferenceOptPreparePass(const ResourcePtr & res)661 bool InferenceOptPreparePass(const ResourcePtr &res) {
662 FuncGraphPtr func_graph = res->func_graph();
663 MS_EXCEPTION_IF_NULL(func_graph);
664 auto prepare_map = GetInferenceOptPreparePhases();
665 auto infer_opt_prepare = opt::Optimizer::MakeOptimizer("inference_prepare", res, prepare_map);
666 (void)infer_opt_prepare->step(func_graph, false);
667 return true;
668 }
669
PynativeOptPass(const ResourcePtr & res)670 bool PynativeOptPass(const ResourcePtr &res) {
671 FuncGraphPtr func_graph = res->func_graph();
672 MS_EXCEPTION_IF_NULL(func_graph);
673 opt::irpass::OptimizeIRPassLib irpass;
674 auto pynative_opt = GetOptPassesPynativeElim(irpass);
675 auto pynative_opt_opt = opt::Optimizer::MakeOptimizer("pynative_opt", res, pynative_opt);
676 (void)pynative_opt_opt->step(func_graph, false);
677 return true;
678 }
679
// Standalone entry point that removes redundant auto-monad nodes from `func_graph`.
// It wraps the graph and its manager in a fresh pipeline::Resource. It then runs, in the listed groups:
// the useless/pure UpdateState eliminaters, the depend/assign/loads UpdateState eliminaters, and
// opt::AutoMonadEliminator. The graph returned by step() is discarded. Both func_graph and its manager must be
// non-null. Always returns true.
bool AutoMonadElimOptPass(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(func_graph->manager());
  auto res = std::make_shared<pipeline::Resource>();
  res->set_func_graph(func_graph);
  res->set_manager(func_graph->manager());

  // opt::irpass::OptimizeIRPassLib is not used here to avoid double free problems in external calls.
  // Instead, the two UpdateState substitutions below are built directly.
  opt::SubstitutionPtr updatestate_useless_node_eliminater =
    opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestateUselessNodeEliminater>(),
                          "updatestate_useless_node_eliminater", prim::kPrimUpdateState);
  opt::SubstitutionPtr updatestate_pure_node_eliminater =
    opt::MakeSubstitution(std::make_shared<opt::irpass::UpdatestatePureNodeEliminater>(),
                          "updatestate_pure_node_eliminater", prim::kPrimUpdateState);

  opt::OptPassConfig updatestate_eliminater = opt::OptPassConfig({
    updatestate_useless_node_eliminater,
    updatestate_pure_node_eliminater,
  });
  opt::OptPassConfig updatestate_depend_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateDependEliminater());
  opt::OptPassConfig updatestate_assign_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateAssignEliminater());
  opt::OptPassConfig updatestate_loads_eliminate = opt::OptPassConfig(opt::irpass::UpdatestateLoadsEliminater());
  opt::OptPassGroupMap elim_map({
    {"updatestate_eliminater", updatestate_eliminater},
    {"updatestate_depend_eliminate", updatestate_depend_eliminate},
    {"updatestate_assign_eliminate", updatestate_assign_eliminate},
    {"updatestate_loads_eliminate", updatestate_loads_eliminate},
    {"auto_monad_eliminator", opt::OptPassConfig(opt::AutoMonadEliminator())},
  });

  auto auto_monad_elim_opt = opt::Optimizer::MakeOptimizer("auto_monad_elim", res, elim_map);
  (void)auto_monad_elim_opt->step(func_graph, false);
  return true;
}
714
// Ordered pass tables, one per compilation flavour. Each entry pairs a pass name with its pass function.
// NOTE(review): the consumer is outside this file; the tables are presumably run front to back, so keep the
// order when editing.

// Passes for the VM backend.
std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"opt_before_recompute", OptBeforeRecomputeGroup},
                                   {"opt_a", OptPassAGroup},
                                   {"clean_after_opta", CleanAfterOptAPass},
                                   {"opt_b", OptPassBGroup},
                                   {"cconv", CconvPass},
                                   {"opt_after_cconv", OptPassAfterCconvGroup},
                                   {"remove_dup_value", RemoveValueNodeDuplicationsPass},
                                   {"tuple_transform", OptPassTransformGraphGroup},
                                   {"add_cache_embedding", AddCacheEmbeddingPass},
                                   {"add_recomputation", AddRecomputationPass},
                                   {"cse_after_recomputation", OptAfterRecomputeGroup}};

// Passes for the GE backend; adds the control-flow ("opt_control") and prepare groups before cconv.
std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                   {"opt_a", OptPassAGroup},
                                   {"clean_after_opta", CleanAfterOptAPass},
                                   {"opt_b", OptPassBGroup},
                                   {"opt_control", ControlGroup},
                                   {"opt_prepare", PrepareGroup},
                                   {"cconv", CconvPass}};

// Passes for PyNative mode.
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
                                         {"opt_b", OptPassBGroup},
                                         {"cconv", CconvPass},
                                         {"transform_top", TransformTopGraphPass},
                                         {"transform_graph", OptPassTransformGraphGroup}};

// Passes for inlining: data-structure simplification plus the leading "a1a2" groups of opt_a.
std::vector<PassItem> kInlinePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"a1a2", OptPassA1A2}};
743 } // namespace pipeline
744 } // namespace mindspore
745