• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "pipeline/jit/action.h"
18 
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include <string>
23 #include <algorithm>
24 #include <functional>
25 
26 #include "ir/func_graph_cloner.h"
27 #include "ir/param_info.h"
28 #include "ir/cell.h"
29 #include "parse/python_adapter.h"
30 #include "abstract/abstract_value.h"
31 #include "frontend/parallel/costmodel_context.h"
32 #include "frontend/parallel/context.h"
33 #include "pipeline/jit/pass.h"
34 #include "pipeline/jit/parse/parse_base.h"
35 #include "pipeline/jit/parse/data_converter.h"
36 #include "pipeline/jit/static_analysis/auto_monad.h"
37 #include "pipeline/jit/static_analysis/order_enforce.h"
38 #include "pipeline/jit/static_analysis/static_analysis.h"
39 #include "pipeline/jit/static_analysis/async_eval_result.h"
40 #include "pipeline/jit/static_analysis/program_specialize.h"
41 #include "pipeline/jit/resource.h"
42 #include "pipeline/jit/remove_value_node_dup.h"
43 #include "pipeline/pynative/pynative_execute.h"
44 #include "frontend/optimizer/optimizer.h"
45 #include "frontend/optimizer/ad/grad.h"
46 #include "frontend/optimizer/py_pass_manager.h"
47 #include "utils/ms_context.h"
48 #include "vm/transform.h"
49 #if ((defined ENABLE_CPU) && (!defined _WIN32))
50 #include "ps/parameter_server.h"
51 #include "ps/scheduler.h"
52 #include "ps/worker.h"
53 #include "fl/worker/fl_worker.h"
54 #include "fl/server/server.h"
55 #endif
56 
57 namespace mindspore {
58 namespace pipeline {
59 namespace {
UpdateFuncGraphParameter(const FuncGraphPtr & func_graph)60 void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
61   MS_EXCEPTION_IF_NULL(func_graph);
62   std::vector<AnfNodePtr> new_paras;
63   for (const auto &param : func_graph->parameters()) {
64     auto param_node = param->cast<ParameterPtr>();
65     MS_EXCEPTION_IF_NULL(param_node);
66     if (param_node->has_default()) {
67       new_paras.push_back(param_node);
68       continue;
69     }
70     AbstractBasePtr par_abs = param_node->abstract();
71     MS_EXCEPTION_IF_NULL(par_abs);
72     if (par_abs->isa<abstract::AbstractUndetermined>() ||
73         (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
74          par_abs->BuildType()->isa<Number>())) {
75       new_paras.push_back(param_node);
76     }
77   }
78   func_graph->set_parameters(new_paras);
79 }
80 
81 // Disable mindRT in the control flow scenario.
ResetMindRTEnable(const ResourcePtr & res)82 void ResetMindRTEnable(const ResourcePtr &res) {
83   MS_EXCEPTION_IF_NULL(res);
84   auto context_ptr = MsContext::GetInstance();
85   MS_EXCEPTION_IF_NULL(context_ptr);
86   if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) == false) {
87     return;
88   }
89 
90   auto func_graph = res->func_graph();
91   MS_EXCEPTION_IF_NULL(func_graph);
92   if (func_graph != nullptr && func_graph->manager() != nullptr) {
93     auto manager = func_graph->manager();
94     size_t graph_nums = manager->func_graphs().size();
95     if (graph_nums == 1) {
96       return;
97     }
98 
99     MS_LOG(INFO) << "Disable mindRT in the multi graphs scenario.";
100     context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
101     // Update the backend.
102     auto new_backend = compile::CreateBackend();
103     new_backend->SetDebugger();
104     res->results()[kBackend] = new_backend;
105   }
106 }
107 
TaskEmitActionForMindRT(const ResourcePtr & res)108 void TaskEmitActionForMindRT(const ResourcePtr &res) {
109   MS_EXCEPTION_IF_NULL(res);
110   // Get the mindRT backend.
111   auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
112   auto mindrt_bc_ptr = std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr);
113   MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
114 
115   // The output of graph compiler is actor.
116   res->results()[kOutput] = mindrt_bc_ptr->CompileGraphs(res->func_graph());
117 }
118 
ExecuteActionForMindRT(const ResourcePtr & res)119 void ExecuteActionForMindRT(const ResourcePtr &res) {
120   MS_EXCEPTION_IF_NULL(res);
121   if (!res->results()[kOutput].is<compile::ActorInfo>()) {
122     MS_LOG(EXCEPTION) << "Execute args error";
123   }
124   const auto &actor_info = res->results()[kOutput].cast<compile::ActorInfo>();
125 
126   // Get the mindRT backend.
127   std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
128   auto mindrt_bc_ptr = (std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr)).get();
129   MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
130 
131   // Construct the graph run function ptr.
132   compile::VmEvalFuncPtr run =
133     std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, actor_info](const VectorRef &args) -> BaseRef {
134       MS_LOG(DEBUG) << "Execute args size " << args.size();
135       VectorRef outputs;
136       mindrt_bc_ptr->RunGraph(actor_info, args, &outputs);
137       MS_LOG(DEBUG) << "out size " << outputs.size();
138       return outputs[0];
139     });
140   res->results()[kOutput] = run;
141 }
142 
143 // Modify the output node of func_graph to add forward nodes used in bprop graph.
ModifyOutputNode(const FuncGraphPtr & func_graph)144 void ModifyOutputNode(const FuncGraphPtr &func_graph) {
145   MS_EXCEPTION_IF_NULL(func_graph);
146   const auto &used_forward_nodes = func_graph->used_forward_nodes();
147 
148   // Get original output node and abstract
149   auto original_output_node = func_graph->output();
150   MS_EXCEPTION_IF_NULL(original_output_node);
151   auto original_output_abs = original_output_node->abstract();
152   MS_EXCEPTION_IF_NULL(original_output_abs);
153 
154   // Create a new make tuple node to hold all forward used nodes.
155   abstract::AbstractBasePtrList added_abs_list;
156   std::vector<AnfNodePtr> added_node_list{NewValueNode(prim::kPrimMakeTuple)};
157   std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(),
158                 [&added_abs_list, &added_node_list](const AnfNodePtr &node) {
159                   MS_EXCEPTION_IF_NULL(node);
160                   added_node_list.push_back(node);
161                   added_abs_list.push_back(node->abstract());
162                 });
163   AnfNodePtr added_output_node = nullptr;
164   AbstractBasePtr added_output_abs = nullptr;
165   if (added_abs_list.empty()) {
166     added_output_node = NewValueNode(MakeValue<int32_t>(1));
167     added_output_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(1));
168   } else {
169     added_output_node = func_graph->NewCNode(added_node_list);
170     added_output_abs = std::make_shared<abstract::AbstractTuple>(added_abs_list);
171   }
172   added_output_node->set_abstract(added_output_abs);
173   MS_LOG(DEBUG) << "Added output node info: " << added_output_node->DebugString();
174 
175   // Merge original output node and used forward nodes to return node.
176   std::vector<AnfNodePtr> new_output_nodes{NewValueNode(prim::kPrimMakeTuple), original_output_node, added_output_node};
177   auto merge_node = func_graph->NewCNode(new_output_nodes);
178   abstract::AbstractBasePtrList new_output_abs{original_output_abs, added_output_abs};
179   merge_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_output_abs));
180   MS_LOG(DEBUG) << "Merge node info: " << merge_node->DebugString();
181   func_graph->set_output(merge_node);
182 
183   // Clear
184   func_graph->set_modify_output(true);
185   func_graph->ClearUsedForwardNodes();
186 }
187 }  // namespace
188 using CompileGraphs = compile::CompileGraphs;
189 using abstract::AnalysisResult;
190 using mindspore::abstract::AnalysisContextPtr;
191 
// Run abstract (type/shape) analysis of func_graph against args_spec.
// When `clear` is set, previously inferred abstracts are wiped first, except:
//  - primitives loaded from MindIR with no front-end evaluator keep theirs;
//  - ValueNodes keep theirs unless the old value was an AbstractFunction.
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                                         const abstract::AbstractBasePtrList &args_spec, bool clear) {
  MS_LOG(DEBUG) << "AbstractAnalyze start";
  auto engine = res->engine();
  MS_EXCEPTION_IF_NULL(engine);
  if (clear) {
    auto manager = res->manager();
    MS_EXCEPTION_IF_NULL(manager);
    engine->Clear();
    for (auto &node : manager->all_nodes()) {
      MS_EXCEPTION_IF_NULL(node);

      // Handle previous inferred value for CNode if is loaded from MindIR
      if (res->is_load()) {
        // If the primitive is not defined in front end,keep the inferred value loaded from MindIR.
        auto primitive = GetCNodePrimitive(node);
        if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
          MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
          continue;
        }
      }

      const AbstractBasePtr &prev_inferred = node->abstract();
      // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
      const bool keep_value_node =
        node->isa<ValueNode>() && (prev_inferred == nullptr || !prev_inferred->isa<abstract::AbstractFunction>());
      if (keep_value_node) {
        continue;
      }
      node->set_abstract(nullptr);
      MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
    }
  }
  auto ret = engine->Run(func_graph, args_spec);
  MS_LOG(INFO) << "function call max depth: " << abstract::FunctionCallMaxDepth()
               << ", simulate call max depth: " << abstract::StackFrameMaxDepth();
  MS_LOG(DEBUG) << "AbstractAnalyze end";
  return ret;
}
228 
// Specialize the polymorphic func_graph under the given analysis context,
// then keep only graphs reachable from the specialized root.
FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                               const abstract::AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(res);
  MS_LOG(DEBUG) << "ProgramSpecialize start";
  abstract::ProgramSpecializer specializer(res->engine());
  FuncGraphPtr specialized = specializer.Run(func_graph, context);
  auto manager = res->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // Drop graphs no longer reachable from the specialized root.
  manager->KeepRoots({specialized});
  MS_LOG(DEBUG) << "ProgramSpecialize end";
  return specialized;
}
241 
// Re-run type/shape inference (with a full clear) and specialization on
// func_graph for the given argument specs, store the result on the resource,
// and return it. Profiling timings are collected when ENABLE_PROFILE is set.
FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                         const abstract::AbstractBasePtrList &args_spec) {
  MS_EXCEPTION_IF_NULL(res);
  MS_LOG(DEBUG) << "Renormalize start";
#ifdef ENABLE_PROFILE
  double start_time = GetTime();
#endif
  // clear = true: discard all previously inferred abstracts.
  abstract::AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec, true);
#ifdef ENABLE_PROFILE
  double infer_done = GetTime();
#endif
  auto specialized = ProgramSpecialize(res, func_graph, result.context);
  res->set_func_graph(specialized);
#ifdef ENABLE_PROFILE
  double specialize_done = GetTime();
  MsProfile::StatTime("renormalize.infer", infer_done - start_time);
  MsProfile::StatTime("renormalize.specialize", specialize_done - infer_done);
#endif

  MS_LOG(DEBUG) << "Renormalize end";

  return specialized;
}
265 
GetLoadedGraph(const ResourcePtr & res)266 const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
267   MS_EXCEPTION_IF_NULL(res);
268   auto manager = res->manager();
269   MS_EXCEPTION_IF_NULL(manager);
270   FuncGraphPtr loaded_graph = nullptr;
271   size_t loaded_graph_num = 0;
272   auto all_graphs = manager->func_graphs();
273   for (auto &graph : all_graphs) {
274     MS_EXCEPTION_IF_NULL(graph);
275     if (graph->has_attr("is_load")) {
276       loaded_graph = graph;
277       loaded_graph_num += 1;
278       res->set_is_load(true);
279     }
280   }
281   if (loaded_graph_num == 0) {
282     return nullptr;
283   }
284   if (loaded_graph_num == 1) {
285     return loaded_graph;
286   }
287   MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num;
288 }
289 
// Verify that each input of the (single) root graph matches the corresponding
// input of the loaded MindIR graph in both shape and element type.
// Shapes compare equal, or both sides having rank <= 1 is accepted as
// compatible; otherwise a ValueError is raised. A dtype mismatch raises
// TypeError. An input-count mismatch raises a generic exception up front.
void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
  MS_EXCEPTION_IF_NULL(res);
  auto manager = res->manager();
  MS_EXCEPTION_IF_NULL(manager);
  // The manager is expected to hold a root graph here; the first root is used.
  FuncGraphPtr root_graph = *(manager->roots().begin());
  auto root_inputs = root_graph->get_inputs();
  auto loaded_inputs = loaded_graph->get_inputs();
  MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
  MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
  size_t root_inputs_num = root_inputs.size();
  size_t loaded_inputs_num = loaded_inputs.size();
  if (root_inputs_num != loaded_inputs_num) {
    MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
                      << loaded_inputs_num;
  }
  for (size_t index = 0; index < root_inputs_num; index++) {
    auto root_input = root_inputs[index];
    auto loaded_input = loaded_inputs[index];

    MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
    MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
    MS_LOG(DEBUG) << "root_input abstract[" << index
                  << "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
    MS_LOG(DEBUG) << "loaded_input abstract [" << index
                  << "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");

    // Shape/Type may be absent or of an unexpected subclass; dyn_cast yields
    // nullptr in either case and the MS_EXCEPTION_IF_NULL checks below treat
    // that as an internal error.
    auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
    auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
    auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
    auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());

    MS_EXCEPTION_IF_NULL(root_shape);
    MS_EXCEPTION_IF_NULL(loaded_shape);
    MS_EXCEPTION_IF_NULL(root_type);
    MS_EXCEPTION_IF_NULL(loaded_type);

    // Equal shapes pass; both being rank <= 1 is also accepted as compatible.
    auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
                    (root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
    if (!shapeEqu) {
      MS_EXCEPTION(ValueError) << "The " << index
                               << " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
                               << ", input shape of loaded graph: " << loaded_shape->ToString();
    }
    if (root_type->type_id() != loaded_type->type_id()) {
      MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
                              << " th input type differ from loaded graph. Input type: " << root_type->ToString()
                              << ", input type of loaded graph: " << loaded_type->ToString();
    }
  }
}
340 
ParseAction(const ResourcePtr & res)341 bool ParseAction(const ResourcePtr &res) {
342   MS_EXCEPTION_IF_NULL(res);
343   if (!res->source_input()) {
344     MS_LOG(EXCEPTION) << "Parse error";
345   }
346 
347   py::object input = res->source_input();
348   parse::Parser::InitParserEnvironment(input);
349   py::module path = py::module::import("os.path");
350   std::string dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();
351 
352   parse::python_adapter::set_python_env_flag(true);
353   parse::python_adapter::SetPythonPath(dir);
354 
355   ValuePtr converted_ret = nullptr;
356   bool converted = parse::ConvertData(input, &converted_ret, true);
357   if (!converted) {
358     MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(input));
359   }
360 
361   FuncGraphPtr top_graph = nullptr;
362   if (py::isinstance<Cell>(input)) {
363     top_graph = parse::MakeTopGraph(input, converted_ret);
364   } else if (converted_ret->isa<FuncGraph>()) {
365     top_graph = converted_ret->cast<FuncGraphPtr>();
366   } else {
367     MS_LOG(EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell.";
368   }
369   parse::Parser::UpdateTopFuncGraph(top_graph);
370 
371   res->set_func_graph(top_graph);
372 
373   FuncGraphManagerPtr manager = res->manager();
374   if (manager == nullptr) {
375     MS_LOG(EXCEPTION) << "Manager is nullptr.";
376   }
377   manager->AddFuncGraph(top_graph);
378   return true;
379 }
380 
381 // obj_map's graphs have the same construct, these graphs can be optimized to one graph.
382 // This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}->
383 // graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
384 // all obj_map's graph shared base_graph
CombineLikeGraphs(const ResourcePtr & res)385 bool CombineLikeGraphs(const ResourcePtr &res) {
386   MS_EXCEPTION_IF_NULL(res);
387   auto &obj_map = parse::data_converter::GetObjGraphs();
388   for (auto it : obj_map) {
389     auto &graphs = it.second;
390     MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
391     auto fg = graphs[0];
392     FuncGraphVector func_graphs = {fg};
393     ClonerPtr cloner = std::make_shared<Cloner>(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
394                                                 std::make_shared<TraceCombileLikeGraphs>());
395     cloner->Run();
396     auto base_graph = cloner->cloned_func_graph()[fg];
397     MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
398 
399     if (fg->paramter_obj_nodes().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
400       continue;
401     }
402     auto &cloned_nodes = *cloner->cloned_node();
403     for (auto &fv : fg->paramter_obj_nodes()) {
404       TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
405       auto param = base_graph->add_parameter();
406       MS_EXCEPTION_IF_NULL(res->manager());
407       auto &node_users = res->manager()->node_users()[fv];
408       for (auto &n : node_users) {
409         // If the user is not in this graph, no need to change.
410         auto cloned = cloned_nodes[n.first];
411         if (cloned == nullptr) {
412           continue;
413         }
414         auto repl_n = cloned->cast<CNodePtr>();
415         MS_EXCEPTION_IF_NULL(repl_n);
416         repl_n->set_input(IntToSize(n.second), param);
417       }
418     }
419     MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();
420 
421     for (auto &g : graphs) {
422       auto &fvs = g->paramter_obj_nodes();
423       std::vector<AnfNodePtr> new_node_inputs;
424       new_node_inputs.push_back(NewValueNode(base_graph));
425       for (auto &p : g->parameters()) {
426         AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p);
427         new_node_inputs.push_back(para_after_cast);
428       }
429       (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end());
430       AnfNodePtr out = g->NewCNodeBefore(g->get_return(), new_node_inputs);
431       g->set_output(out);
432       const int recursive_level = 4;
433       MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(recursive_level);
434     }
435     MS_LOG(DEBUG) << "End combine graph:" << it.first;
436   }
437   return true;
438 }
439 
SymbolResolveAction(const ResourcePtr & res)440 bool SymbolResolveAction(const ResourcePtr &res) {
441   MS_EXCEPTION_IF_NULL(res);
442   if (res->manager() == nullptr) {
443     MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
444   }
445   auto func_graph = res->func_graph();
446   if (func_graph == nullptr) {
447     MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
448   }
449   bool ret = parse::ResolveFuncGraph(func_graph, res);
450   // Remove unused nodes in cnode order list.
451   if (func_graph) {
452     func_graph->EraseUnusedNodeInOrder();
453     for (auto fg : func_graph->func_graphs_used_total()) {
454       if (fg) {
455         fg->EraseUnusedNodeInOrder();
456       }
457     }
458   }
459   return ret;
460 }
461 
AutoMonadAction(const ResourcePtr & res)462 bool AutoMonadAction(const ResourcePtr &res) {
463   MS_EXCEPTION_IF_NULL(res);
464   if (res->manager() == nullptr) {
465     MS_LOG(EXCEPTION) << "Auto-Monad failed, manager is null";
466   }
467   auto func_graph = res->func_graph();
468   if (func_graph == nullptr) {
469     MS_LOG(EXCEPTION) << "Auto-Monad failed, graph is null";
470   }
471   (void)pipeline::AutoMonad(func_graph);
472   return true;
473 }
474 
OrderEnforceAction(const ResourcePtr & res)475 bool OrderEnforceAction(const ResourcePtr &res) {
476   MS_EXCEPTION_IF_NULL(res);
477   if (res->manager() == nullptr) {
478     MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null";
479   }
480   auto func_graph = res->func_graph();
481   if (func_graph == nullptr) {
482     MS_LOG(EXCEPTION) << "Order-Enforce error, graph is null";
483   }
484   pipeline::OrderEnforce(func_graph);
485   return true;
486 }
487 
InferenceOptPrepareAction(const ResourcePtr & res)488 bool InferenceOptPrepareAction(const ResourcePtr &res) {
489   MS_EXCEPTION_IF_NULL(res);
490   if (res->manager() == nullptr) {
491     MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
492   }
493   if (res->func_graph() == nullptr) {
494     MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
495   }
496   return InferenceOptPreparePass(res);
497 }
498 
// Infer and specialize the top func graph: collect AbstractRef specs for all
// default-valued (hyper) parameters, run abstract analysis, specialize the
// resulting context's graph, prune unused nodes/parameters, and — when a
// MindIR-loaded graph is present — validate the root inputs against it.
bool AbstractSpecializeAction(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  if (res->func_graph() == nullptr) {
    MS_LOG(EXCEPTION) << "AbstractSpecialize error";
  }
  FuncGraphPtr func_graph = res->func_graph();
  abstract::AbstractBasePtrList args_spec = res->args_spec();
  auto context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  context->ParallelParameterContextInitShape(func_graph);

  // Get original loaded graph to check inputs later
  auto loaded_graph_ptr = GetLoadedGraph(res);
  // suppose that there is not KeywordArgument for the top graph
  // get the hyper parameter
  for (const auto &param : func_graph->parameters()) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      auto value = param_node->default_param();
      MS_EXCEPTION_IF_NULL(value);
      // Wrap the default value as an AbstractRef keyed by the parameter name,
      // so the analyzer treats it as a mutable parameter reference.
      auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
      auto ref_key = std::make_shared<RefKey>(param_node->name());
      auto abs_ref_key = ref_key->ToAbstract();
      auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
      // The parallel context may restore/checkpoint sliced shapes around the
      // abstract (for auto-parallel parameter slicing).
      context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
      args_spec.push_back(abs_ref);
      context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);
    }
  }
  // Analyze
  AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);

  // The top graph may be replaced by infer, update the top graph when the infer is done
  parse::Parser::UpdateTopFuncGraph(result.context->func_graph());

  // Specialize
  FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context);
  res->set_func_graph(new_fg);

  // Remove unused nodes in cnode order list, this is prepared for auto-monad.
  if (new_fg) {
    new_fg->EraseUnusedNodeInOrder();
    for (auto fg : new_fg->func_graphs_used_total()) {
      if (fg) {
        fg->EraseUnusedNodeInOrder();
      }
    }
  }
  // Check input after abstract when there is a loaded graph
  if (loaded_graph_ptr != nullptr) {
    CheckRootInputShapeAndType(res, loaded_graph_ptr);
  }

  UpdateFuncGraphParameter(new_fg);
  MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
  return true;
}
557 
// Run each (name, pass-function) item over the resource in order. A pass
// returning false aborts the pipeline with an exception. Each pass is timed
// under the profiler, and when MS_CTX_SAVE_GRAPHS_FLAG is set (ENABLE_DUMP_IR
// builds) the graph after each pass is dumped to numbered IR files.
bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
  MS_EXCEPTION_IF_NULL(res);
  size_t counter = 0;
  for (auto &pass : passes) {
    // WITH(...) scopes the profiler step named after the pass around the lambda.
    WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
      MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
      auto result = pass.second(res);
      if (!result) {
        MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
      }
#ifdef ENABLE_DUMP_IR
      if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && res->func_graph() != nullptr) {
        // Dump file name encodes the pass order so dumps sort chronologically.
        auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
        auto func_graph = res->func_graph();
        MS_EXCEPTION_IF_NULL(func_graph);
        func_graph->DumpFuncGraph(fg_name);
        DumpIR(fg_name + ".ir", func_graph);
        ExportIR(fg_name + ".dat", func_graph);
        MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
      }
#endif
      counter++;
      MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
    };
  }

  return true;
}
586 
OptInlineAction(const ResourcePtr & res)587 bool OptInlineAction(const ResourcePtr &res) {
588   if (parallel::ParallelContext::GetInstance()->parallel_mode() == "semi_auto_parallel" ||
589       parallel::ParallelContext::GetInstance()->parallel_mode() == "auto_parallel") {
590     return OptimizeAction(res, kInlinePasses);
591   }
592   if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) {
593     return OptimizeAction(res, kInlinePasses);
594   }
595   return true;
596 }
597 
GeOptimizeAction(const ResourcePtr & res)598 bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
599 
// Run the VM pass list. In parameter-server mode (non-Windows CPU builds), a
// server-side communication-op fusion pass is appended to the list first.
// NOTE(review): kVmPasses is a mutable global; if this action runs more than
// once in PS mode the fusion pass is appended repeatedly — confirm intended.
bool VmOptimizeAction(const ResourcePtr &res) {
#if ((defined ENABLE_CPU) && (!defined _WIN32))
  if (ps::PSContext::instance()->is_ps_mode()) {
    kVmPasses.push_back({"server_communication_op_fusion", ps::Util::FuseServerCommOps});
  }
#endif
  return OptimizeAction(res, kVmPasses);
}
608 
PynativeElimOpt(const ResourcePtr & res)609 bool PynativeElimOpt(const ResourcePtr &res) {
610   MS_EXCEPTION_IF_NULL(res);
611   if (res->manager() == nullptr) {
612     MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
613   }
614   if (res->func_graph() == nullptr) {
615     MS_LOG(EXCEPTION) << "PynativeElimOpt error, graph is null.";
616   }
617   return PynativeOptPass(res);
618 }
619 
IsCtrlSink()620 static bool IsCtrlSink() {
621   auto ms_ctx = MsContext::GetInstance();
622   if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {
623     return false;
624   }
625 
626   std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET);
627   if (device_target != kAscendDevice) {
628     return false;
629   }
630 
631   if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
632     return false;
633   }
634 
635   if (!ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
636     return false;
637   }
638   return true;
639 }
640 
CheckGraphOutputConstOrParameter(const FuncGraphPtr & func_graph)641 bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
642   if (func_graph != nullptr) {
643     AnfNodePtr output = func_graph->output();
644     if (output != nullptr && (output->isa<ValueNode>() || output->isa<Parameter>())) {
645       return true;
646     }
647   }
648   return false;
649 }
650 
// Eliminate forward cnodes for ms_function graphs in PyNative mode: run the
// grad process on the func graph, attach the resulting grad graph to the
// graph executor, and extend the forward graph's output with the forward
// nodes the bprop graph needs. No-op in graph mode; when exporting or when
// grad is not required, only the eliminate_forward flag is reset.
bool EliminateForwardCNode(const ResourcePtr &res) {
  // This function only works in Pynative mode. The func_graph is decorated by ms_function.
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    return true;
  }

  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  auto phase = graph_executor->phase();
  MS_LOG(DEBUG) << "The phase of current pipeline graph is: " << phase;
  // Exporting graph in PyNative mode or only running forward process no need to do this action.
  auto pynative_exec = pynative::PynativeExecutor::GetInstance();
  if (phase.find("export") == 0 || !pynative_exec->grad_flag()) {
    MS_LOG(DEBUG) << "When exporting graph or only running forward process, no need to eliminate forward cnode.";
    auto grad_exec = pynative_exec->grad_executor();
    grad_exec->set_eliminate_forward(true);
    return true;
  }

  // Run grad process for func_graph and replace forward nodes with its output tensors.
  MS_LOG(INFO) << "Run eliminate forward nodes action.";
  MS_EXCEPTION_IF_NULL(res);
  auto ms_func_graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(ms_func_graph);
  auto grad_exec = pynative_exec->grad_executor();
  bool eliminate_forward = grad_exec->eliminate_forward();
  // Elimination stays enabled only when the graph uses no sub-graphs.
  grad_exec->set_eliminate_forward(eliminate_forward && ms_func_graph->func_graphs_used().empty());
  auto grad_graph = ad::Grad(ms_func_graph, res);
  MS_EXCEPTION_IF_NULL(grad_graph);
  graph_executor->SetGradGraph(grad_graph, phase);
  // Append the forward nodes used by the bprop graph to the forward output.
  ModifyOutputNode(ms_func_graph);

  // Keep roots for only keeping forward func graph in resource.
  auto manager = res->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->KeepRoots({ms_func_graph});

  // Restore the default so subsequent compilations start with it enabled.
  grad_exec->set_eliminate_forward(true);
  return true;
}
691 
TaskEmitAction(const ResourcePtr & res)692 bool TaskEmitAction(const ResourcePtr &res) {
693   MS_EXCEPTION_IF_NULL(res);
694   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
695       CheckGraphOutputConstOrParameter(res->func_graph())) {
696     return true;
697   }
698   if (res->func_graph() == nullptr) {
699     MS_LOG(EXCEPTION) << "TaskEmit args error";
700   }
701   // Disable mindRT in the control flow scenario.
702   ResetMindRTEnable(res);
703   FuncGraphPtr func_graph = res->func_graph();
704   MS_EXCEPTION_IF_NULL(func_graph);
705   auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
706   auto context_ptr = MsContext::GetInstance();
707   std::string backend = MsContext::GetInstance()->backend_policy();
708   MS_EXCEPTION_IF_NULL(context_ptr);
709   auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
710   if (func_graph->ContainMultiTarget() || !task_sink) {
711     bc_ptr->set_is_multi_graph_sink(false);
712     context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
713     context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
714   } else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
715     std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
716     auto manager = func_graph->manager();
717     auto graphs = manager->func_graphs();
718     bool exist_while =
719       std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
720     if (device_target == kAscendDevice && backend != kMsVm && !exist_while) {
721       MS_LOG(INFO) << "Run graph mode with multigraph sink.";
722       bc_ptr->set_is_multi_graph_sink(true);
723       context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
724     } else {
725       MS_LOG(INFO) << "Run graph mode with vm.";
726       bc_ptr->set_is_multi_graph_sink(false);
727       context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
728       context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
729     }
730   }
731 
732   // The graph compiling of mindRT.
733   if ((backend == kMsConvert) && context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
734     TaskEmitActionForMindRT(res);
735     return true;
736   }
737 
738   // The graph compiling of control sink.
739   if (IsCtrlSink() && backend == kMsConvert) {
740     res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
741     return true;
742   }
743   std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
744   if (bc_ptr->name() == kMsConvert) {
745     cut_list = compile::GetMsNonlinearOps();
746   }
747   std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
748   res->results()[kOutput] = compile->CompileAndLink(func_graph);
749   return true;
750 }
751 
ExecuteAction(const ResourcePtr & res)752 bool ExecuteAction(const ResourcePtr &res) {
753   MS_EXCEPTION_IF_NULL(res);
754   if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
755       CheckGraphOutputConstOrParameter(res->func_graph())) {
756     return true;
757   }
758   if (res->results().count(kOutput) == 0) {
759     MS_LOG(EXCEPTION) << "Execute args error";
760   }
761   std::string backend = MsContext::GetInstance()->backend_policy();
762   // The graph running of mindRT.
763   if ((backend == kMsConvert) && MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
764     ExecuteActionForMindRT(res);
765     return true;
766   }
767 
768   // The graph running of control sink.
769   if (IsCtrlSink() && backend == kMsConvert) {
770     if (!res->results()[kOutput].is<GraphId>()) {
771       MS_LOG(EXCEPTION) << "Execute args error";
772     }
773     auto graph_id = res->results()[kOutput].cast<GraphId>();
774     std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
775     compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
776     MS_EXCEPTION_IF_NULL(msbc_ptr);
777     compile::VmEvalFuncPtr run =
778       std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
779         MS_LOG(INFO) << "Execute args size " << args.size();
780         auto outs = msbc_ptr->RunGraph(graph_id, args);
781         MS_LOG(DEBUG) << "out size " << outs.size();
782         return outs[0];
783       });
784     res->results()[kOutput] = run;
785     return true;
786   }
787 
788   if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
789     MS_LOG(EXCEPTION) << "Execute args error";
790   }
791   compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
792   if (vm == nullptr) {
793     MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";
794     return true;
795   }
796   compile::VmEvalFuncPtr run =
797     std::make_shared<compile::VmEvalFunc>(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1));
798   res->results()[kOutput] = run;
799   return true;
800 }
801 
802 #if ((defined ENABLE_CPU) && (!defined _WIN32))
StartPSWorkerAction(const ResourcePtr &)803 bool StartPSWorkerAction(const ResourcePtr &) {
804   ps::Worker::GetInstance().Run();
805   return true;
806 }
StartFLWorkerAction(const ResourcePtr &)807 bool StartFLWorkerAction(const ResourcePtr &) {
808   fl::worker::FLWorker::GetInstance().Run();
809   return true;
810 }
811 
StartPSServerAction(const ResourcePtr & res)812 bool StartPSServerAction(const ResourcePtr &res) {
813   MS_EXCEPTION_IF_NULL(res);
814   FuncGraphPtr func_graph = res->func_graph();
815   auto &ps = ps::ParameterServer::GetInstance();
816   ps.Run(func_graph);
817   return true;
818 }
819 
// Configure and launch the FL/PS server: build the per-round configs and the
// cipher config from PSContext, initialize the server singleton according to
// the server mode, then enter its run loop (blocking).
bool StartServerAction(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  FuncGraphPtr func_graph = res->func_graph();
  const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
  uint32_t worker_num = ps::PSContext::instance()->initial_worker_num();
  uint32_t server_num = ps::PSContext::instance()->initial_server_num();
  uint16_t fl_server_port = ps::PSContext::instance()->fl_server_port();

  // Update model threshold is a certain ratio of start_fl_job threshold.
  // update_model_threshold = start_fl_job_threshold * update_model_ratio.
  size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
  float update_model_ratio = ps::PSContext::instance()->update_model_ratio();
  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
  uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
  uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();

  // RoundConfig fields are positional: {name, check_timeout, time_window,
  // check_count, threshold_count[, server_num_as_threshold]} — presumably;
  // TODO(review): confirm field order against fl::server::RoundConfig.
  std::vector<fl::server::RoundConfig> rounds_config = {
    {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
    {"updateModel", true, update_model_time_window, true, update_model_threshold},
    {"getModel"},
    {"pullWeight"},
    {"pushWeight", false, 3000, true, server_num, true},
    {"pushMetrics", false, 3000, true, 1}};

  float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
  uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
  size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold() + 1;

  // Each secure-aggregation round threshold is derived from the previous one,
  // scaled by share_secrets_ratio but never below update_model_threshold.
  size_t exchange_keys_threshold =
    std::max(static_cast<size_t>(std::ceil(start_fl_job_threshold * share_secrets_ratio)), update_model_threshold);
  size_t get_keys_threshold =
    std::max(static_cast<size_t>(std::ceil(exchange_keys_threshold * share_secrets_ratio)), update_model_threshold);
  size_t share_secrets_threshold =
    std::max(static_cast<size_t>(std::ceil(get_keys_threshold * share_secrets_ratio)), update_model_threshold);
  size_t get_secrets_threshold =
    std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
  size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
                                          reconstruct_secrets_threshold);
#ifdef ENABLE_ARMOUR
  // Pairwise-encryption mode adds the secure-aggregation rounds.
  std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  if (encrypt_type == ps::kPWEncryptType) {
    MS_LOG(INFO) << "Add secure aggregation rounds.";
    rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold});
    rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold});
    rounds_config.push_back({"shareSecrets", true, cipher_time_window, true, share_secrets_threshold});
    rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
    rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
    rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
  }
#endif
  fl::server::CipherConfig cipher_config = {
    share_secrets_ratio,     cipher_time_window,    exchange_keys_threshold, get_keys_threshold,
    share_secrets_threshold, get_secrets_threshold, client_list_threshold,   reconstruct_secrets_threshold};

  size_t executor_threshold = 0;
  if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
    // FL/hybrid mode: executor waits for update_model_threshold clients;
    // the server runs in secure mode with the FL port open.
    executor_threshold = update_model_threshold;
    fl::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
                                                 executor_threshold);
  } else if (server_mode_ == ps::kServerModePS) {
    // Plain PS mode: executor threshold is the worker count, no FL port.
    executor_threshold = worker_num;
    fl::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
                                                 executor_threshold);
  } else {
    MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
    return false;  // Unreachable: MS_LOG(EXCEPTION) throws.
  }
  // Blocks for the server lifetime.
  fl::server::Server::GetInstance().Run();
  return true;
}
890 
StartPSSchedulerAction(const ResourcePtr &)891 bool StartPSSchedulerAction(const ResourcePtr &) {
892   ps::Scheduler::GetInstance().Run();
893   return true;
894 }
895 #endif
896 
897 // The parallel primitive related valuenode might be partitioned so that its value changes by device,
898 // that will result in a synchronization error due to different executing order.
899 // Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
900 // the final solution will be proposed later as a parallel feature.
KeepValueNodeDuplication(const AnfNodePtr & value_node,const ResourcePtr & res)901 bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) {
902   MS_EXCEPTION_IF_NULL(res);
903   MS_EXCEPTION_IF_NULL(res->manager());
904   auto &node_users = res->manager()->node_users();
905   auto &users = node_users[value_node];
906   auto used_by_keep_value_prim =
907     std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
908       MS_EXCEPTION_IF_NULL(user.first);
909       auto cnode = user.first->cast<CNodePtr>();
910       if (cnode == nullptr) {
911         return false;
912       }
913       auto prim_node = cnode->input(0);
914       if (IsValueNode<Primitive>(prim_node)) {
915         auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
916         MS_EXCEPTION_IF_NULL(prim);
917         // value_node is referenced by some parallel primitive
918         return prim->HasAttr("keep_value_node_input");
919       }
920       return false;
921     });
922   return used_by_keep_value_prim;
923 }
924 
RemoveValueNodeDuplicationsAction(const ResourcePtr & res)925 bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
926   MS_EXCEPTION_IF_NULL(res);
927   FuncGraphPtr func_graph = res->func_graph();
928   if (func_graph == nullptr) {
929     MS_LOG(EXCEPTION) << "Remove value node duplications error.";
930   }
931   auto manager = res->manager();
932   // Remove duplicated value nodes, due to replace operation, can't use reference.
933   auto value_nodes = func_graph->value_nodes();
934   HashCache hash_cache;
935   HashValue hashes;
936   for (const auto &value_pair : value_nodes) {
937     if (KeepValueNodeDuplication(value_pair.first, res)) {
938       continue;
939     }
940     TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
941   }
942   return true;
943 }
944 
// Split the graph into pipeline-parallel stages via the pipeline-split pass.
bool PipelineSplitAction(const ResourcePtr &res) { return PipelineSplitPass(res); }
// Final validation pass: check the graph is legal before backend compiling.
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
947 
// Install a func_graph previously exported to MindIR (stored as a cell
// attribute) into the resource, verify the caller's arguments are compatible
// with the exported signature, append hyper-parameter abstracts, and re-run
// abstract analysis so every node carries an abstract.
bool SetMindIRGraphAction(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  res->set_is_load(true);
  // The source input is the Cell that holds the loaded MindIR graph.
  auto cell = py::cast<CellPtr>(res->source_input());
  if (cell == nullptr) {
    MS_LOG(EXCEPTION) << "The graph loaded from mindir is null.";
  }
  const std::string mindir_graph = "graph_load_from_mindir";
  auto obj = cell->GetAttr(mindir_graph);
  if (obj == nullptr) {
    MS_LOG(EXCEPTION) << "The graph loaded from mindir is null. The cell has not attribute: " << mindir_graph;
  }
  auto fg = GetValue<FuncGraphPtr>(obj);
  if (fg == nullptr) {
    MS_LOG(EXCEPTION) << "The graph loaded from mindir is null.";
  }
  res->set_func_graph(fg);
  // A freshly-loaded graph has no manager yet; attach the resource's manager.
  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    auto res_mng = res->manager();
    MS_EXCEPTION_IF_NULL(res_mng);
    res_mng->AddFuncGraph(fg);
    fg->set_manager(res_mng);
  }
  // Broaden the caller's argument abstracts (drop concrete values, keep
  // type/shape) so the comparison is against the exported signature shape.
  abstract::AbstractBasePtrList broaded_args;
  const auto &args_spec_list = res->args_spec();
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_args),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         if (arg->GetValueTrack() != kAnyValue) {
                           return arg->Broaden();
                         }
                         return arg;
                       });

  // Broadened abstracts of the exported graph's formal inputs.
  abstract::AbstractBasePtrList func_args;
  const auto inputs = fg->get_inputs();
  (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(func_args),
                       [](const AnfNodePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         return arg->abstract()->Broaden();
                       });
  if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
    MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before."
                      << " Please check the args is same with export.\n"
                      << "Export input args info:" << abstract::ArgsToString(func_args) << "\n"
                      << "The input args info:" << abstract::ArgsToString(broaded_args);
  }

  // suppose that there is not KeywordArgument for the top graph
  // get the hyper parameter
  // Parameters with defaults are hyper-parameters (weights); append a RefKey
  // abstract for each so analysis sees them as additional inputs.
  for (const auto &param : fg->parameters()) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      auto value = param_node->default_param();
      MS_EXCEPTION_IF_NULL(value);
      auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
      auto ref_key = std::make_shared<RefKey>(param_node->name());
      auto abs_ref_key = ref_key->ToAbstract();
      auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
      broaded_args.push_back(abs_ref);
    }
  }
  // Re-run abstract analysis, then copy the inferred abstracts from the
  // analysis cache back onto the nodes and clear the cache.
  (void)AbstractAnalyze(res, res->func_graph(), broaded_args, true);
  auto it = abstract::AnalysisResultCacheMgr::GetInstance().begin();
  auto it_end = abstract::AnalysisResultCacheMgr::GetInstance().end();
  for (; it != it_end; ++it) {
    it->first->node()->set_abstract(it->second->abstract());
  }
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  return true;
}
1021 
ActionPyStub(const ResourcePtr & res,opt::python_pass::Phase phase)1022 bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
1023   MS_EXCEPTION_IF_NULL(res->manager());
1024   MS_EXCEPTION_IF_NULL(res->func_graph());
1025   auto ppm = opt::python_pass::PyPassManager::GetInstance();
1026   ppm->SetResource(res);
1027   return ppm->GetPassGroup(phase)->Run(res->func_graph());
1028 }
1029 
PreAdActionPyStub(const ResourcePtr & res)1030 bool PreAdActionPyStub(const ResourcePtr &res) {
1031   if (!ActionPyStub(res, opt::python_pass::Phase::PREAD)) {
1032     MS_LOG(DEBUG) << "No Match.";
1033   }
1034   return true;
1035 }
1036 
OptActionVmPyStub(const ResourcePtr & res)1037 bool OptActionVmPyStub(const ResourcePtr &res) {
1038   if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
1039     if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
1040       // Renomalize
1041       FuncGraphPtr func_graph = res->func_graph();
1042       MS_EXCEPTION_IF_NULL(func_graph);
1043       abstract::AbstractBasePtrList args_spec;
1044       auto parameters = func_graph->parameters();
1045       (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
1046                            [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
1047       FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
1048       res->set_func_graph(new_fg);
1049       res->set_args_spec(args_spec);
1050     }
1051     if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
1052       return VmOptimizeAction(res);
1053     }
1054   }
1055   return true;
1056 }
1057 
OptActionGePyStub(const ResourcePtr & res)1058 bool OptActionGePyStub(const ResourcePtr &res) {
1059   if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
1060     if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
1061       // Renomalize
1062       FuncGraphPtr func_graph = res->func_graph();
1063       MS_EXCEPTION_IF_NULL(func_graph);
1064       abstract::AbstractBasePtrList args_spec;
1065       auto parameters = func_graph->parameters();
1066       (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
1067                            [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
1068       FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
1069       res->set_func_graph(new_fg);
1070       res->set_args_spec(args_spec);
1071     }
1072     if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
1073       return GeOptimizeAction(res);
1074     }
1075   }
1076   return true;
1077 }
1078 
CommonPipeline()1079 static std::vector<ActionItem> CommonPipeline() {
1080   std::vector<ActionItem> actions;
1081 
1082   // Parse the python ast to ANF graph
1083   (void)actions.emplace_back(std::make_pair("parse", ParseAction));
1084 
1085   // Resolve the python func
1086   (void)actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
1087 
1088   auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
1089   if (!multi_graphs) {
1090     (void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
1091   }
1092 
1093   (void)actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
1094   // Evaluate type and shape, and specialize
1095   (void)actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
1096   // Auto-monad for side-effects handling.
1097   (void)actions.emplace_back(std::make_pair("auto_monad", AutoMonadAction));
1098   // Do data structure simplifications and inline
1099   (void)actions.emplace_back(std::make_pair("inline", OptInlineAction));
1100   // Add pre-ad, post-inline python pass stub
1101   (void)actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));
1102   // Do PipelineSplit
1103   (void)actions.emplace_back(std::make_pair("pipeline_split", PipelineSplitAction));
1104 
1105   return actions;
1106 }
1107 
GePipeline()1108 std::vector<ActionItem> GePipeline() {
1109   auto actions = CommonPipeline();
1110   // optimize
1111   (void)actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
1112   // Add opt-stage python pass stub
1113   (void)actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
1114   (void)actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
1115   (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
1116   (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1117   return actions;
1118 }
1119 
VmPipeline()1120 std::vector<ActionItem> VmPipeline() {
1121   auto actions = CommonPipeline();
1122 
1123   // optimize
1124   (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
1125 
1126   // Add opt-stage python pass stub
1127   (void)actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
1128 
1129   (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
1130 
1131   // eliminate forward cnode for grad graph
1132   (void)actions.emplace_back(std::make_pair("eliminate_forward_cnode", EliminateForwardCNode));
1133 
1134   (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1135 #if ((defined ENABLE_CPU) && (!defined _WIN32))
1136   if (ps::PSContext::instance()->is_worker()) {
1137     std::string server_mode = ps::PSContext::instance()->server_mode();
1138     if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
1139       (void)actions.emplace_back(std::make_pair("worker", StartFLWorkerAction));
1140     } else {
1141       (void)actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
1142     }
1143   }
1144 #endif
1145   // compile the ANF graph
1146   (void)actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
1147 
1148   // to execute the graph
1149   (void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
1150 
1151   return actions;
1152 }
1153 
BackendPipeline()1154 std::vector<ActionItem> BackendPipeline() {
1155   std::vector<ActionItem> actions;
1156   // compile the ANF graph
1157   (void)actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
1158   // to execute the graph
1159   (void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
1160   return actions;
1161 }
MindIRPipeline()1162 std::vector<ActionItem> MindIRPipeline() {
1163   std::vector<ActionItem> actions;
1164   // Set funcGraph loaded from MindIR to resource.
1165   (void)actions.emplace_back(std::make_pair("load_mindir", SetMindIRGraphAction));
1166   (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1167   // compile the ANF graph
1168   (void)actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
1169   // to execute the graph
1170   (void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
1171   return actions;
1172 }
1173 #if ((defined ENABLE_CPU) && (!defined _WIN32))
ServerPipeline()1174 std::vector<ActionItem> ServerPipeline() {
1175   auto actions = CommonPipeline();
1176   (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
1177   (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1178   (void)actions.emplace_back(std::make_pair("server", StartServerAction));
1179   return actions;
1180 }
1181 
PServerPipeline()1182 std::vector<ActionItem> PServerPipeline() {
1183   auto actions = CommonPipeline();
1184   (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
1185   (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
1186   (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1187   (void)actions.emplace_back(std::make_pair("pserver", StartPSServerAction));
1188   return actions;
1189 }
1190 
PSchedulerPipeline()1191 std::vector<ActionItem> PSchedulerPipeline() {
1192   std::vector<ActionItem> actions;
1193   (void)actions.emplace_back(std::make_pair("scheduler", StartPSSchedulerAction));
1194   return actions;
1195 }
1196 #endif
1197 }  // namespace pipeline
1198 }  // namespace mindspore
1199