1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/jit/action.h"
18
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 #include <string>
23 #include <algorithm>
24 #include <functional>
25
26 #include "ir/func_graph_cloner.h"
27 #include "ir/param_info.h"
28 #include "ir/cell.h"
29 #include "parse/python_adapter.h"
30 #include "abstract/abstract_value.h"
31 #include "frontend/parallel/costmodel_context.h"
32 #include "frontend/parallel/context.h"
33 #include "pipeline/jit/pass.h"
34 #include "pipeline/jit/parse/parse_base.h"
35 #include "pipeline/jit/parse/data_converter.h"
36 #include "pipeline/jit/static_analysis/auto_monad.h"
37 #include "pipeline/jit/static_analysis/order_enforce.h"
38 #include "pipeline/jit/static_analysis/static_analysis.h"
39 #include "pipeline/jit/static_analysis/async_eval_result.h"
40 #include "pipeline/jit/static_analysis/program_specialize.h"
41 #include "pipeline/jit/resource.h"
42 #include "pipeline/jit/remove_value_node_dup.h"
43 #include "pipeline/pynative/pynative_execute.h"
44 #include "frontend/optimizer/optimizer.h"
45 #include "frontend/optimizer/ad/grad.h"
46 #include "frontend/optimizer/py_pass_manager.h"
47 #include "utils/ms_context.h"
48 #include "vm/transform.h"
49 #if ((defined ENABLE_CPU) && (!defined _WIN32))
50 #include "ps/parameter_server.h"
51 #include "ps/scheduler.h"
52 #include "ps/worker.h"
53 #include "fl/worker/fl_worker.h"
54 #include "fl/server/server.h"
55 #endif
56
57 namespace mindspore {
58 namespace pipeline {
59 namespace {
UpdateFuncGraphParameter(const FuncGraphPtr & func_graph)60 void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph) {
61 MS_EXCEPTION_IF_NULL(func_graph);
62 std::vector<AnfNodePtr> new_paras;
63 for (const auto ¶m : func_graph->parameters()) {
64 auto param_node = param->cast<ParameterPtr>();
65 MS_EXCEPTION_IF_NULL(param_node);
66 if (param_node->has_default()) {
67 new_paras.push_back(param_node);
68 continue;
69 }
70 AbstractBasePtr par_abs = param_node->abstract();
71 MS_EXCEPTION_IF_NULL(par_abs);
72 if (par_abs->isa<abstract::AbstractUndetermined>() ||
73 (MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && par_abs->BuildType() != nullptr &&
74 par_abs->BuildType()->isa<Number>())) {
75 new_paras.push_back(param_node);
76 }
77 }
78 func_graph->set_parameters(new_paras);
79 }
80
81 // Disable mindRT in the control flow scenario.
ResetMindRTEnable(const ResourcePtr & res)82 void ResetMindRTEnable(const ResourcePtr &res) {
83 MS_EXCEPTION_IF_NULL(res);
84 auto context_ptr = MsContext::GetInstance();
85 MS_EXCEPTION_IF_NULL(context_ptr);
86 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT) == false) {
87 return;
88 }
89
90 auto func_graph = res->func_graph();
91 MS_EXCEPTION_IF_NULL(func_graph);
92 if (func_graph != nullptr && func_graph->manager() != nullptr) {
93 auto manager = func_graph->manager();
94 size_t graph_nums = manager->func_graphs().size();
95 if (graph_nums == 1) {
96 return;
97 }
98
99 MS_LOG(INFO) << "Disable mindRT in the multi graphs scenario.";
100 context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
101 // Update the backend.
102 auto new_backend = compile::CreateBackend();
103 new_backend->SetDebugger();
104 res->results()[kBackend] = new_backend;
105 }
106 }
107
TaskEmitActionForMindRT(const ResourcePtr & res)108 void TaskEmitActionForMindRT(const ResourcePtr &res) {
109 MS_EXCEPTION_IF_NULL(res);
110 // Get the mindRT backend.
111 auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
112 auto mindrt_bc_ptr = std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr);
113 MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
114
115 // The output of graph compiler is actor.
116 res->results()[kOutput] = mindrt_bc_ptr->CompileGraphs(res->func_graph());
117 }
118
ExecuteActionForMindRT(const ResourcePtr & res)119 void ExecuteActionForMindRT(const ResourcePtr &res) {
120 MS_EXCEPTION_IF_NULL(res);
121 if (!res->results()[kOutput].is<compile::ActorInfo>()) {
122 MS_LOG(EXCEPTION) << "Execute args error";
123 }
124 const auto &actor_info = res->results()[kOutput].cast<compile::ActorInfo>();
125
126 // Get the mindRT backend.
127 std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
128 auto mindrt_bc_ptr = (std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr)).get();
129 MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
130
131 // Construct the graph run function ptr.
132 compile::VmEvalFuncPtr run =
133 std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, actor_info](const VectorRef &args) -> BaseRef {
134 MS_LOG(DEBUG) << "Execute args size " << args.size();
135 VectorRef outputs;
136 mindrt_bc_ptr->RunGraph(actor_info, args, &outputs);
137 MS_LOG(DEBUG) << "out size " << outputs.size();
138 return outputs[0];
139 });
140 res->results()[kOutput] = run;
141 }
142
143 // Modify the output node of func_graph to add forward nodes used in bprop graph.
ModifyOutputNode(const FuncGraphPtr & func_graph)144 void ModifyOutputNode(const FuncGraphPtr &func_graph) {
145 MS_EXCEPTION_IF_NULL(func_graph);
146 const auto &used_forward_nodes = func_graph->used_forward_nodes();
147
148 // Get original output node and abstract
149 auto original_output_node = func_graph->output();
150 MS_EXCEPTION_IF_NULL(original_output_node);
151 auto original_output_abs = original_output_node->abstract();
152 MS_EXCEPTION_IF_NULL(original_output_abs);
153
154 // Create a new make tuple node to hold all forward used nodes.
155 abstract::AbstractBasePtrList added_abs_list;
156 std::vector<AnfNodePtr> added_node_list{NewValueNode(prim::kPrimMakeTuple)};
157 std::for_each(used_forward_nodes.begin(), used_forward_nodes.end(),
158 [&added_abs_list, &added_node_list](const AnfNodePtr &node) {
159 MS_EXCEPTION_IF_NULL(node);
160 added_node_list.push_back(node);
161 added_abs_list.push_back(node->abstract());
162 });
163 AnfNodePtr added_output_node = nullptr;
164 AbstractBasePtr added_output_abs = nullptr;
165 if (added_abs_list.empty()) {
166 added_output_node = NewValueNode(MakeValue<int32_t>(1));
167 added_output_abs = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int32Imm>(1));
168 } else {
169 added_output_node = func_graph->NewCNode(added_node_list);
170 added_output_abs = std::make_shared<abstract::AbstractTuple>(added_abs_list);
171 }
172 added_output_node->set_abstract(added_output_abs);
173 MS_LOG(DEBUG) << "Added output node info: " << added_output_node->DebugString();
174
175 // Merge original output node and used forward nodes to return node.
176 std::vector<AnfNodePtr> new_output_nodes{NewValueNode(prim::kPrimMakeTuple), original_output_node, added_output_node};
177 auto merge_node = func_graph->NewCNode(new_output_nodes);
178 abstract::AbstractBasePtrList new_output_abs{original_output_abs, added_output_abs};
179 merge_node->set_abstract(std::make_shared<abstract::AbstractTuple>(new_output_abs));
180 MS_LOG(DEBUG) << "Merge node info: " << merge_node->DebugString();
181 func_graph->set_output(merge_node);
182
183 // Clear
184 func_graph->set_modify_output(true);
185 func_graph->ClearUsedForwardNodes();
186 }
187 } // namespace
188 using CompileGraphs = compile::CompileGraphs;
189 using abstract::AnalysisResult;
190 using mindspore::abstract::AnalysisContextPtr;
191
// Run static analysis (abstract interpretation) of func_graph against the
// given argument specs. When `clear` is set, previously inferred abstracts are
// wiped first: value nodes keep their abstract unless it is a function
// abstract, and (for graphs loaded from MindIR) nodes whose primitive has no
// front-end evaluator keep the inference results loaded with the model.
abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                                         const abstract::AbstractBasePtrList &args_spec, bool clear) {
  MS_LOG(DEBUG) << "AbstractAnalyze start";
  auto engine = res->engine();
  MS_EXCEPTION_IF_NULL(engine);
  if (clear) {
    auto manager = res->manager();
    MS_EXCEPTION_IF_NULL(manager);
    engine->Clear();
    for (auto &node : manager->all_nodes()) {
      MS_EXCEPTION_IF_NULL(node);

      // Loaded-from-MindIR case: if the primitive is not defined in the front
      // end, keep the inferred value that came with the model.
      if (res->is_load()) {
        auto primitive = GetCNodePrimitive(node);
        if (primitive != nullptr && abstract::GetPrimEvaluator(primitive, engine) == nullptr) {
          MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
          continue;
        }
      }

      // A value node keeps its previous abstract unless that abstract is an
      // AbstractFunction; everything else is reset for re-inference.
      const AbstractBasePtr &previous_abs = node->abstract();
      const bool keep_abstract =
        node->isa<ValueNode>() && !(previous_abs != nullptr && previous_abs->isa<abstract::AbstractFunction>());
      if (!keep_abstract) {
        node->set_abstract(nullptr);
        MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
      }
    }
  }
  auto analysis_result = engine->Run(func_graph, args_spec);
  MS_LOG(INFO) << "function call max depth: " << abstract::FunctionCallMaxDepth()
               << ", simulate call max depth: " << abstract::StackFrameMaxDepth();
  MS_LOG(DEBUG) << "AbstractAnalyze end";
  return analysis_result;
}
228
// Run the program specializer over func_graph under the given analysis
// context, then keep only the specialized graph as a root in the manager.
FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                               const abstract::AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(res);
  MS_LOG(DEBUG) << "ProgramSpecialize start";
  abstract::ProgramSpecializer specializer(res->engine());
  FuncGraphPtr specialized = specializer.Run(func_graph, context);
  auto manager = res->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->KeepRoots({specialized});
  MS_LOG(DEBUG) << "ProgramSpecialize end";
  return specialized;
}
241
// Re-run analysis from scratch (clear = true) and specialize the graph with
// the fresh analysis context. The resulting graph is installed on the
// resource and returned.
FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
                         const abstract::AbstractBasePtrList &args_spec) {
  MS_EXCEPTION_IF_NULL(res);
  MS_LOG(DEBUG) << "Renormalize start";
#ifdef ENABLE_PROFILE
  double time_start = GetTime();
#endif
  abstract::AnalysisResult analysis = AbstractAnalyze(res, func_graph, args_spec, true);
#ifdef ENABLE_PROFILE
  double time_infer = GetTime();
#endif
  auto specialized = ProgramSpecialize(res, func_graph, analysis.context);
  res->set_func_graph(specialized);
#ifdef ENABLE_PROFILE
  double time_end = GetTime();
  MsProfile::StatTime("renormalize.infer", time_infer - time_start);
  MsProfile::StatTime("renormalize.specialize", time_end - time_infer);
#endif

  MS_LOG(DEBUG) << "Renormalize end";

  return specialized;
}
265
GetLoadedGraph(const ResourcePtr & res)266 const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
267 MS_EXCEPTION_IF_NULL(res);
268 auto manager = res->manager();
269 MS_EXCEPTION_IF_NULL(manager);
270 FuncGraphPtr loaded_graph = nullptr;
271 size_t loaded_graph_num = 0;
272 auto all_graphs = manager->func_graphs();
273 for (auto &graph : all_graphs) {
274 MS_EXCEPTION_IF_NULL(graph);
275 if (graph->has_attr("is_load")) {
276 loaded_graph = graph;
277 loaded_graph_num += 1;
278 res->set_is_load(true);
279 }
280 }
281 if (loaded_graph_num == 0) {
282 return nullptr;
283 }
284 if (loaded_graph_num == 1) {
285 return loaded_graph;
286 }
287 MS_LOG(EXCEPTION) << "The loaded sub graph currently should less than 2, but got " << loaded_graph_num;
288 }
289
CheckRootInputShapeAndType(const ResourcePtr & res,const FuncGraphPtr & loaded_graph)290 void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
291 MS_EXCEPTION_IF_NULL(res);
292 auto manager = res->manager();
293 MS_EXCEPTION_IF_NULL(manager);
294 FuncGraphPtr root_graph = *(manager->roots().begin());
295 auto root_inputs = root_graph->get_inputs();
296 auto loaded_inputs = loaded_graph->get_inputs();
297 MS_LOG(DEBUG) << "root_graph: " << root_graph->ToString();
298 MS_LOG(DEBUG) << "loaded_graph: " << loaded_graph->ToString();
299 size_t root_inputs_num = root_inputs.size();
300 size_t loaded_inputs_num = loaded_inputs.size();
301 if (root_inputs_num != loaded_inputs_num) {
302 MS_LOG(EXCEPTION) << "The inputs number " << root_inputs_num << " not equal to the inputs number of loaded graph "
303 << loaded_inputs_num;
304 }
305 for (size_t index = 0; index < root_inputs_num; index++) {
306 auto root_input = root_inputs[index];
307 auto loaded_input = loaded_inputs[index];
308
309 MS_LOG(DEBUG) << "root_input[" << index << "]: " << root_input->DebugString(1);
310 MS_LOG(DEBUG) << "loaded_input[" << index << "]: " << loaded_input->DebugString(1);
311 MS_LOG(DEBUG) << "root_input abstract[" << index
312 << "]: " << (root_input->abstract() ? root_input->abstract()->ToString() : "NULL");
313 MS_LOG(DEBUG) << "loaded_input abstract [" << index
314 << "]: " << (loaded_input->abstract() ? loaded_input->abstract()->ToString() : "NULL");
315
316 auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
317 auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
318 auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
319 auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
320
321 MS_EXCEPTION_IF_NULL(root_shape);
322 MS_EXCEPTION_IF_NULL(loaded_shape);
323 MS_EXCEPTION_IF_NULL(root_type);
324 MS_EXCEPTION_IF_NULL(loaded_type);
325
326 auto shapeEqu = (root_shape->shape() == loaded_shape->shape()) ||
327 (root_shape->shape().size() <= 1 && loaded_shape->shape().size() <= 1);
328 if (!shapeEqu) {
329 MS_EXCEPTION(ValueError) << "The " << index
330 << " th input shape differ from loaded graph. Input shape: " << root_shape->ToString()
331 << ", input shape of loaded graph: " << loaded_shape->ToString();
332 }
333 if (root_type->type_id() != loaded_type->type_id()) {
334 MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
335 << " th input type differ from loaded graph. Input type: " << root_type->ToString()
336 << ", input type of loaded graph: " << loaded_type->ToString();
337 }
338 }
339 }
340
ParseAction(const ResourcePtr & res)341 bool ParseAction(const ResourcePtr &res) {
342 MS_EXCEPTION_IF_NULL(res);
343 if (!res->source_input()) {
344 MS_LOG(EXCEPTION) << "Parse error";
345 }
346
347 py::object input = res->source_input();
348 parse::Parser::InitParserEnvironment(input);
349 py::module path = py::module::import("os.path");
350 std::string dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();
351
352 parse::python_adapter::set_python_env_flag(true);
353 parse::python_adapter::SetPythonPath(dir);
354
355 ValuePtr converted_ret = nullptr;
356 bool converted = parse::ConvertData(input, &converted_ret, true);
357 if (!converted) {
358 MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(input));
359 }
360
361 FuncGraphPtr top_graph = nullptr;
362 if (py::isinstance<Cell>(input)) {
363 top_graph = parse::MakeTopGraph(input, converted_ret);
364 } else if (converted_ret->isa<FuncGraph>()) {
365 top_graph = converted_ret->cast<FuncGraphPtr>();
366 } else {
367 MS_LOG(EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell.";
368 }
369 parse::Parser::UpdateTopFuncGraph(top_graph);
370
371 res->set_func_graph(top_graph);
372
373 FuncGraphManagerPtr manager = res->manager();
374 if (manager == nullptr) {
375 MS_LOG(EXCEPTION) << "Manager is nullptr.";
376 }
377 manager->AddFuncGraph(top_graph);
378 return true;
379 }
380
381 // obj_map's graphs have the same construct, these graphs can be optimized to one graph.
382 // This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}->
383 // graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
384 // all obj_map's graph shared base_graph
CombineLikeGraphs(const ResourcePtr & res)385 bool CombineLikeGraphs(const ResourcePtr &res) {
386 MS_EXCEPTION_IF_NULL(res);
387 auto &obj_map = parse::data_converter::GetObjGraphs();
388 for (auto it : obj_map) {
389 auto &graphs = it.second;
390 MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size();
391 auto fg = graphs[0];
392 FuncGraphVector func_graphs = {fg};
393 ClonerPtr cloner = std::make_shared<Cloner>(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
394 std::make_shared<TraceCombileLikeGraphs>());
395 cloner->Run();
396 auto base_graph = cloner->cloned_func_graph()[fg];
397 MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();
398
399 if (fg->paramter_obj_nodes().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE)) {
400 continue;
401 }
402 auto &cloned_nodes = *cloner->cloned_node();
403 for (auto &fv : fg->paramter_obj_nodes()) {
404 TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fv->debug_info()));
405 auto param = base_graph->add_parameter();
406 MS_EXCEPTION_IF_NULL(res->manager());
407 auto &node_users = res->manager()->node_users()[fv];
408 for (auto &n : node_users) {
409 // If the user is not in this graph, no need to change.
410 auto cloned = cloned_nodes[n.first];
411 if (cloned == nullptr) {
412 continue;
413 }
414 auto repl_n = cloned->cast<CNodePtr>();
415 MS_EXCEPTION_IF_NULL(repl_n);
416 repl_n->set_input(IntToSize(n.second), param);
417 }
418 }
419 MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size();
420
421 for (auto &g : graphs) {
422 auto &fvs = g->paramter_obj_nodes();
423 std::vector<AnfNodePtr> new_node_inputs;
424 new_node_inputs.push_back(NewValueNode(base_graph));
425 for (auto &p : g->parameters()) {
426 AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p);
427 new_node_inputs.push_back(para_after_cast);
428 }
429 (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end());
430 AnfNodePtr out = g->NewCNodeBefore(g->get_return(), new_node_inputs);
431 g->set_output(out);
432 const int recursive_level = 4;
433 MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(recursive_level);
434 }
435 MS_LOG(DEBUG) << "End combine graph:" << it.first;
436 }
437 return true;
438 }
439
SymbolResolveAction(const ResourcePtr & res)440 bool SymbolResolveAction(const ResourcePtr &res) {
441 MS_EXCEPTION_IF_NULL(res);
442 if (res->manager() == nullptr) {
443 MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null";
444 }
445 auto func_graph = res->func_graph();
446 if (func_graph == nullptr) {
447 MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null";
448 }
449 bool ret = parse::ResolveFuncGraph(func_graph, res);
450 // Remove unused nodes in cnode order list.
451 if (func_graph) {
452 func_graph->EraseUnusedNodeInOrder();
453 for (auto fg : func_graph->func_graphs_used_total()) {
454 if (fg) {
455 fg->EraseUnusedNodeInOrder();
456 }
457 }
458 }
459 return ret;
460 }
461
AutoMonadAction(const ResourcePtr & res)462 bool AutoMonadAction(const ResourcePtr &res) {
463 MS_EXCEPTION_IF_NULL(res);
464 if (res->manager() == nullptr) {
465 MS_LOG(EXCEPTION) << "Auto-Monad failed, manager is null";
466 }
467 auto func_graph = res->func_graph();
468 if (func_graph == nullptr) {
469 MS_LOG(EXCEPTION) << "Auto-Monad failed, graph is null";
470 }
471 (void)pipeline::AutoMonad(func_graph);
472 return true;
473 }
474
OrderEnforceAction(const ResourcePtr & res)475 bool OrderEnforceAction(const ResourcePtr &res) {
476 MS_EXCEPTION_IF_NULL(res);
477 if (res->manager() == nullptr) {
478 MS_LOG(EXCEPTION) << "Order-Enforce error, manager is null";
479 }
480 auto func_graph = res->func_graph();
481 if (func_graph == nullptr) {
482 MS_LOG(EXCEPTION) << "Order-Enforce error, graph is null";
483 }
484 pipeline::OrderEnforce(func_graph);
485 return true;
486 }
487
InferenceOptPrepareAction(const ResourcePtr & res)488 bool InferenceOptPrepareAction(const ResourcePtr &res) {
489 MS_EXCEPTION_IF_NULL(res);
490 if (res->manager() == nullptr) {
491 MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null.";
492 }
493 if (res->func_graph() == nullptr) {
494 MS_LOG(EXCEPTION) << "InferenceOptPrepare error, graph is null.";
495 }
496 return InferenceOptPreparePass(res);
497 }
498
AbstractSpecializeAction(const ResourcePtr & res)499 bool AbstractSpecializeAction(const ResourcePtr &res) {
500 MS_EXCEPTION_IF_NULL(res);
501 if (res->func_graph() == nullptr) {
502 MS_LOG(EXCEPTION) << "AbstractSpecialize error";
503 }
504 FuncGraphPtr func_graph = res->func_graph();
505 abstract::AbstractBasePtrList args_spec = res->args_spec();
506 auto context = parallel::ParallelContext::GetInstance();
507 MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
508 context->ParallelParameterContextInitShape(func_graph);
509
510 // Get original loaded graph to check inputs later
511 auto loaded_graph_ptr = GetLoadedGraph(res);
512 // suppose that there is not KeywordArgument for the top graph
513 // get the hyper parameter
514 for (const auto ¶m : func_graph->parameters()) {
515 auto param_node = std::static_pointer_cast<Parameter>(param);
516 MS_EXCEPTION_IF_NULL(param_node);
517 if (param_node->has_default()) {
518 auto value = param_node->default_param();
519 MS_EXCEPTION_IF_NULL(value);
520 auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
521 auto ref_key = std::make_shared<RefKey>(param_node->name());
522 auto abs_ref_key = ref_key->ToAbstract();
523 auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
524 context->ParallelParameterContextRestoreShape(func_graph, param_node, abs_ref);
525 args_spec.push_back(abs_ref);
526 context->ParallelParameterContextCkptShape(func_graph, param_node, abs_ref);
527 }
528 }
529 // Analyze
530 AnalysisResult result = AbstractAnalyze(res, func_graph, args_spec);
531
532 // The top graph may be replaced by infer, update the top graph when the infer is done
533 parse::Parser::UpdateTopFuncGraph(result.context->func_graph());
534
535 // Specialize
536 FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context);
537 res->set_func_graph(new_fg);
538
539 // Remove unused nodes in cnode order list, this is prepared for auto-monad.
540 if (new_fg) {
541 new_fg->EraseUnusedNodeInOrder();
542 for (auto fg : new_fg->func_graphs_used_total()) {
543 if (fg) {
544 fg->EraseUnusedNodeInOrder();
545 }
546 }
547 }
548 // Check input after abstract when there is a loaded graph
549 if (loaded_graph_ptr != nullptr) {
550 CheckRootInputShapeAndType(res, loaded_graph_ptr);
551 }
552
553 UpdateFuncGraphParameter(new_fg);
554 MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
555 return true;
556 }
557
OptimizeAction(const ResourcePtr & res,const std::vector<PassItem> & passes)558 bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes) {
559 MS_EXCEPTION_IF_NULL(res);
560 size_t counter = 0;
561 for (auto &pass : passes) {
562 WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res, &counter]() {
563 MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
564 auto result = pass.second(res);
565 if (!result) {
566 MS_LOG(EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
567 }
568 #ifdef ENABLE_DUMP_IR
569 if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG) && res->func_graph() != nullptr) {
570 auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
571 auto func_graph = res->func_graph();
572 MS_EXCEPTION_IF_NULL(func_graph);
573 func_graph->DumpFuncGraph(fg_name);
574 DumpIR(fg_name + ".ir", func_graph);
575 ExportIR(fg_name + ".dat", func_graph);
576 MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
577 }
578 #endif
579 counter++;
580 MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
581 };
582 }
583
584 return true;
585 }
586
OptInlineAction(const ResourcePtr & res)587 bool OptInlineAction(const ResourcePtr &res) {
588 if (parallel::ParallelContext::GetInstance()->parallel_mode() == "semi_auto_parallel" ||
589 parallel::ParallelContext::GetInstance()->parallel_mode() == "auto_parallel") {
590 return OptimizeAction(res, kInlinePasses);
591 }
592 if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) {
593 return OptimizeAction(res, kInlinePasses);
594 }
595 return true;
596 }
597
GeOptimizeAction(const ResourcePtr & res)598 bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
599
VmOptimizeAction(const ResourcePtr & res)600 bool VmOptimizeAction(const ResourcePtr &res) {
601 #if ((defined ENABLE_CPU) && (!defined _WIN32))
602 if (ps::PSContext::instance()->is_ps_mode()) {
603 kVmPasses.push_back({"server_communication_op_fusion", ps::Util::FuseServerCommOps});
604 }
605 #endif
606 return OptimizeAction(res, kVmPasses);
607 }
608
PynativeElimOpt(const ResourcePtr & res)609 bool PynativeElimOpt(const ResourcePtr &res) {
610 MS_EXCEPTION_IF_NULL(res);
611 if (res->manager() == nullptr) {
612 MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
613 }
614 if (res->func_graph() == nullptr) {
615 MS_LOG(EXCEPTION) << "PynativeElimOpt error, graph is null.";
616 }
617 return PynativeOptPass(res);
618 }
619
IsCtrlSink()620 static bool IsCtrlSink() {
621 auto ms_ctx = MsContext::GetInstance();
622 if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {
623 return false;
624 }
625
626 std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET);
627 if (device_target != kAscendDevice) {
628 return false;
629 }
630
631 if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
632 return false;
633 }
634
635 if (!ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK)) {
636 return false;
637 }
638 return true;
639 }
640
CheckGraphOutputConstOrParameter(const FuncGraphPtr & func_graph)641 bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
642 if (func_graph != nullptr) {
643 AnfNodePtr output = func_graph->output();
644 if (output != nullptr && (output->isa<ValueNode>() || output->isa<Parameter>())) {
645 return true;
646 }
647 }
648 return false;
649 }
650
EliminateForwardCNode(const ResourcePtr & res)651 bool EliminateForwardCNode(const ResourcePtr &res) {
652 // This function only works in Pynative mode. The func_graph is decorated by ms_function.
653 if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
654 return true;
655 }
656
657 auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
658 MS_EXCEPTION_IF_NULL(graph_executor);
659 auto phase = graph_executor->phase();
660 MS_LOG(DEBUG) << "The phase of current pipeline graph is: " << phase;
661 // Exporting graph in PyNative mode or only running forward process no need to do this action.
662 auto pynative_exec = pynative::PynativeExecutor::GetInstance();
663 if (phase.find("export") == 0 || !pynative_exec->grad_flag()) {
664 MS_LOG(DEBUG) << "When exporting graph or only running forward process, no need to eliminate forward cnode.";
665 auto grad_exec = pynative_exec->grad_executor();
666 grad_exec->set_eliminate_forward(true);
667 return true;
668 }
669
670 // Run grad process for func_graph and replace forward nodes with its output tensors.
671 MS_LOG(INFO) << "Run eliminate forward nodes action.";
672 MS_EXCEPTION_IF_NULL(res);
673 auto ms_func_graph = res->func_graph();
674 MS_EXCEPTION_IF_NULL(ms_func_graph);
675 auto grad_exec = pynative_exec->grad_executor();
676 bool eliminate_forward = grad_exec->eliminate_forward();
677 grad_exec->set_eliminate_forward(eliminate_forward && ms_func_graph->func_graphs_used().empty());
678 auto grad_graph = ad::Grad(ms_func_graph, res);
679 MS_EXCEPTION_IF_NULL(grad_graph);
680 graph_executor->SetGradGraph(grad_graph, phase);
681 ModifyOutputNode(ms_func_graph);
682
683 // Keep roots for only keeping forward func graph in resource.
684 auto manager = res->manager();
685 MS_EXCEPTION_IF_NULL(manager);
686 manager->KeepRoots({ms_func_graph});
687
688 grad_exec->set_eliminate_forward(true);
689 return true;
690 }
691
TaskEmitAction(const ResourcePtr & res)692 bool TaskEmitAction(const ResourcePtr &res) {
693 MS_EXCEPTION_IF_NULL(res);
694 if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
695 CheckGraphOutputConstOrParameter(res->func_graph())) {
696 return true;
697 }
698 if (res->func_graph() == nullptr) {
699 MS_LOG(EXCEPTION) << "TaskEmit args error";
700 }
701 // Disable mindRT in the control flow scenario.
702 ResetMindRTEnable(res);
703 FuncGraphPtr func_graph = res->func_graph();
704 MS_EXCEPTION_IF_NULL(func_graph);
705 auto bc_ptr = res->results()[kBackend].cast<compile::BackendPtr>();
706 auto context_ptr = MsContext::GetInstance();
707 std::string backend = MsContext::GetInstance()->backend_policy();
708 MS_EXCEPTION_IF_NULL(context_ptr);
709 auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
710 if (func_graph->ContainMultiTarget() || !task_sink) {
711 bc_ptr->set_is_multi_graph_sink(false);
712 context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
713 context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
714 } else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
715 std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
716 auto manager = func_graph->manager();
717 auto graphs = manager->func_graphs();
718 bool exist_while =
719 std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
720 if (device_target == kAscendDevice && backend != kMsVm && !exist_while) {
721 MS_LOG(INFO) << "Run graph mode with multigraph sink.";
722 bc_ptr->set_is_multi_graph_sink(true);
723 context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
724 } else {
725 MS_LOG(INFO) << "Run graph mode with vm.";
726 bc_ptr->set_is_multi_graph_sink(false);
727 context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
728 context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
729 }
730 }
731
732 // The graph compiling of mindRT.
733 if ((backend == kMsConvert) && context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
734 TaskEmitActionForMindRT(res);
735 return true;
736 }
737
738 // The graph compiling of control sink.
739 if (IsCtrlSink() && backend == kMsConvert) {
740 res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph));
741 return true;
742 }
743 std::vector<PrimitivePtr> cut_list = compile::nonlinear_ops;
744 if (bc_ptr->name() == kMsConvert) {
745 cut_list = compile::GetMsNonlinearOps();
746 }
747 std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
748 res->results()[kOutput] = compile->CompileAndLink(func_graph);
749 return true;
750 }
751
ExecuteAction(const ResourcePtr & res)752 bool ExecuteAction(const ResourcePtr &res) {
753 MS_EXCEPTION_IF_NULL(res);
754 if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
755 CheckGraphOutputConstOrParameter(res->func_graph())) {
756 return true;
757 }
758 if (res->results().count(kOutput) == 0) {
759 MS_LOG(EXCEPTION) << "Execute args error";
760 }
761 std::string backend = MsContext::GetInstance()->backend_policy();
762 // The graph running of mindRT.
763 if ((backend == kMsConvert) && MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
764 ExecuteActionForMindRT(res);
765 return true;
766 }
767
768 // The graph running of control sink.
769 if (IsCtrlSink() && backend == kMsConvert) {
770 if (!res->results()[kOutput].is<GraphId>()) {
771 MS_LOG(EXCEPTION) << "Execute args error";
772 }
773 auto graph_id = res->results()[kOutput].cast<GraphId>();
774 std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
775 compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
776 MS_EXCEPTION_IF_NULL(msbc_ptr);
777 compile::VmEvalFuncPtr run =
778 std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
779 MS_LOG(INFO) << "Execute args size " << args.size();
780 auto outs = msbc_ptr->RunGraph(graph_id, args);
781 MS_LOG(DEBUG) << "out size " << outs.size();
782 return outs[0];
783 });
784 res->results()[kOutput] = run;
785 return true;
786 }
787
788 if (!res->results()[kOutput].is<compile::FinalVMPtr>()) {
789 MS_LOG(EXCEPTION) << "Execute args error";
790 }
791 compile::FinalVMPtr vm = res->results()[kOutput].cast<compile::FinalVMPtr>();
792 if (vm == nullptr) {
793 MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM";
794 return true;
795 }
796 compile::VmEvalFuncPtr run =
797 std::make_shared<compile::VmEvalFunc>(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1));
798 res->results()[kOutput] = run;
799 return true;
800 }
801
802 #if ((defined ENABLE_CPU) && (!defined _WIN32))
// Starts the parameter-server Worker's run loop; the resource is unused.
// NOTE(review): presumably blocks until the worker finishes — confirm in ps::Worker::Run.
bool StartPSWorkerAction(const ResourcePtr &) {
  ps::Worker::GetInstance().Run();
  return true;
}
// Starts the federated-learning Worker's run loop; the resource is unused.
// NOTE(review): presumably blocks until the worker finishes — confirm in fl::worker::FLWorker::Run.
bool StartFLWorkerAction(const ResourcePtr &) {
  fl::worker::FLWorker::GetInstance().Run();
  return true;
}
811
// Starts the parameter-server role: hands the compiled func graph to the
// ParameterServer singleton and runs it.
bool StartPSServerAction(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  FuncGraphPtr func_graph = res->func_graph();
  auto &ps = ps::ParameterServer::GetInstance();
  ps.Run(func_graph);
  return true;
}
819
// Configures and starts the FL / parameter-server server. Round thresholds
// are derived from the PSContext configuration; secure-aggregation rounds are
// added only when built with ENABLE_ARMOUR and pairwise encryption is chosen.
bool StartServerAction(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  FuncGraphPtr func_graph = res->func_graph();
  const std::string &server_mode_ = ps::PSContext::instance()->server_mode();
  uint32_t worker_num = ps::PSContext::instance()->initial_worker_num();
  uint32_t server_num = ps::PSContext::instance()->initial_server_num();
  uint16_t fl_server_port = ps::PSContext::instance()->fl_server_port();

  // Update model threshold is a certain ratio of start_fl_job threshold.
  // update_model_threshold = start_fl_job_threshold * update_model_ratio.
  size_t start_fl_job_threshold = ps::PSContext::instance()->start_fl_job_threshold();
  float update_model_ratio = ps::PSContext::instance()->update_model_ratio();
  size_t update_model_threshold = static_cast<size_t>(std::ceil(start_fl_job_threshold * update_model_ratio));
  uint64_t start_fl_job_time_window = ps::PSContext::instance()->start_fl_job_time_window();
  uint64_t update_model_time_window = ps::PSContext::instance()->update_model_time_window();

  // NOTE(review): field meanings of the brace-initialized entries are defined by
  // fl::server::RoundConfig — confirm the positional arguments against its declaration.
  std::vector<fl::server::RoundConfig> rounds_config = {
    {"startFLJob", true, start_fl_job_time_window, true, start_fl_job_threshold},
    {"updateModel", true, update_model_time_window, true, update_model_threshold},
    {"getModel"},
    {"pullWeight"},
    {"pushWeight", false, 3000, true, server_num, true},
    {"pushMetrics", false, 3000, true, 1}};

  float share_secrets_ratio = ps::PSContext::instance()->share_secrets_ratio();
  uint64_t cipher_time_window = ps::PSContext::instance()->cipher_time_window();
  size_t reconstruct_secrets_threshold = ps::PSContext::instance()->reconstruct_secrets_threshold() + 1;

  // Each cipher-round threshold is scaled down from the previous one by
  // share_secrets_ratio, but never allowed below update_model_threshold.
  size_t exchange_keys_threshold =
    std::max(static_cast<size_t>(std::ceil(start_fl_job_threshold * share_secrets_ratio)), update_model_threshold);
  size_t get_keys_threshold =
    std::max(static_cast<size_t>(std::ceil(exchange_keys_threshold * share_secrets_ratio)), update_model_threshold);
  size_t share_secrets_threshold =
    std::max(static_cast<size_t>(std::ceil(get_keys_threshold * share_secrets_ratio)), update_model_threshold);
  size_t get_secrets_threshold =
    std::max(static_cast<size_t>(std::ceil(share_secrets_threshold * share_secrets_ratio)), update_model_threshold);
  size_t client_list_threshold = std::max(static_cast<size_t>(std::ceil(update_model_threshold * share_secrets_ratio)),
                                          reconstruct_secrets_threshold);
#ifdef ENABLE_ARMOUR
  std::string encrypt_type = ps::PSContext::instance()->encrypt_type();
  if (encrypt_type == ps::kPWEncryptType) {
    MS_LOG(INFO) << "Add secure aggregation rounds.";
    rounds_config.push_back({"exchangeKeys", true, cipher_time_window, true, exchange_keys_threshold});
    rounds_config.push_back({"getKeys", true, cipher_time_window, true, get_keys_threshold});
    rounds_config.push_back({"shareSecrets", true, cipher_time_window, true, share_secrets_threshold});
    rounds_config.push_back({"getSecrets", true, cipher_time_window, true, get_secrets_threshold});
    rounds_config.push_back({"getClientList", true, cipher_time_window, true, client_list_threshold});
    rounds_config.push_back({"reconstructSecrets", true, cipher_time_window, true, reconstruct_secrets_threshold});
  }
#endif
  fl::server::CipherConfig cipher_config = {
    share_secrets_ratio, cipher_time_window, exchange_keys_threshold, get_keys_threshold,
    share_secrets_threshold, get_secrets_threshold, client_list_threshold, reconstruct_secrets_threshold};

  // Executor threshold comes from update_model_threshold in FL/hybrid mode and
  // from worker_num in PS mode; any other mode is rejected.
  size_t executor_threshold = 0;
  if (server_mode_ == ps::kServerModeFL || server_mode_ == ps::kServerModeHybrid) {
    executor_threshold = update_model_threshold;
    fl::server::Server::GetInstance().Initialize(true, true, fl_server_port, rounds_config, cipher_config, func_graph,
                                                 executor_threshold);
  } else if (server_mode_ == ps::kServerModePS) {
    executor_threshold = worker_num;
    fl::server::Server::GetInstance().Initialize(true, false, 0, rounds_config, cipher_config, func_graph,
                                                 executor_threshold);
  } else {
    MS_LOG(EXCEPTION) << "Server mode " << server_mode_ << " is not supported.";
    // NOTE(review): MS_LOG(EXCEPTION) is expected to throw, making this return unreachable — confirm.
    return false;
  }
  fl::server::Server::GetInstance().Run();
  return true;
}
890
// Runs the parameter-server Scheduler's run loop; the resource is unused.
bool StartPSSchedulerAction(const ResourcePtr &) {
  ps::Scheduler::GetInstance().Run();
  return true;
}
895 #endif
896
897 // The parallel primitive related valuenode might be partitioned so that its value changes by device,
898 // that will result in a synchronization error due to different executing order.
899 // Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive,
900 // the final solution will be proposed later as a parallel feature.
KeepValueNodeDuplication(const AnfNodePtr & value_node,const ResourcePtr & res)901 bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) {
902 MS_EXCEPTION_IF_NULL(res);
903 MS_EXCEPTION_IF_NULL(res->manager());
904 auto &node_users = res->manager()->node_users();
905 auto &users = node_users[value_node];
906 auto used_by_keep_value_prim =
907 std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
908 MS_EXCEPTION_IF_NULL(user.first);
909 auto cnode = user.first->cast<CNodePtr>();
910 if (cnode == nullptr) {
911 return false;
912 }
913 auto prim_node = cnode->input(0);
914 if (IsValueNode<Primitive>(prim_node)) {
915 auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
916 MS_EXCEPTION_IF_NULL(prim);
917 // value_node is referenced by some parallel primitive
918 return prim->HasAttr("keep_value_node_input");
919 }
920 return false;
921 });
922 return used_by_keep_value_prim;
923 }
924
RemoveValueNodeDuplicationsAction(const ResourcePtr & res)925 bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
926 MS_EXCEPTION_IF_NULL(res);
927 FuncGraphPtr func_graph = res->func_graph();
928 if (func_graph == nullptr) {
929 MS_LOG(EXCEPTION) << "Remove value node duplications error.";
930 }
931 auto manager = res->manager();
932 // Remove duplicated value nodes, due to replace operation, can't use reference.
933 auto value_nodes = func_graph->value_nodes();
934 HashCache hash_cache;
935 HashValue hashes;
936 for (const auto &value_pair : value_nodes) {
937 if (KeepValueNodeDuplication(value_pair.first, res)) {
938 continue;
939 }
940 TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
941 }
942 return true;
943 }
944
// Thin wrapper: runs the pipeline-split pass over the resource.
bool PipelineSplitAction(const ResourcePtr &res) { return PipelineSplitPass(res); }
// Thin wrapper: runs the graph-validation pass over the resource.
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
947
// Loads the FuncGraph previously exported to MindIR from the source Cell's
// "graph_load_from_mindir" attribute, verifies the current call arguments are
// compatible with the exported inputs, appends default-valued parameters
// (hyper parameters) as abstract refs, then re-runs abstract analysis and
// writes the inferred abstracts back onto the graph nodes.
bool SetMindIRGraphAction(const ResourcePtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  res->set_is_load(true);
  // The compile source must be a Cell carrying the loaded MindIR graph.
  auto cell = py::cast<CellPtr>(res->source_input());
  if (cell == nullptr) {
    MS_LOG(EXCEPTION) << "The graph loaded from mindir is null.";
  }
  const std::string mindir_graph = "graph_load_from_mindir";
  auto obj = cell->GetAttr(mindir_graph);
  if (obj == nullptr) {
    MS_LOG(EXCEPTION) << "The graph loaded from mindir is null. The cell has not attribute: " << mindir_graph;
  }
  auto fg = GetValue<FuncGraphPtr>(obj);
  if (fg == nullptr) {
    MS_LOG(EXCEPTION) << "The graph loaded from mindir is null.";
  }
  res->set_func_graph(fg);
  // Attach the loaded graph to the resource's manager when it has none yet.
  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    auto res_mng = res->manager();
    MS_EXCEPTION_IF_NULL(res_mng);
    res_mng->AddFuncGraph(fg);
    fg->set_manager(res_mng);
  }
  // Broaden concrete argument values so the compatibility check compares
  // types/shapes rather than exact values.
  abstract::AbstractBasePtrList broaded_args;
  const auto &args_spec_list = res->args_spec();
  (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broaded_args),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         if (arg->GetValueTrack() != kAnyValue) {
                           return arg->Broaden();
                         }
                         return arg;
                       });

  // Collect the broadened abstracts of the exported graph's inputs.
  abstract::AbstractBasePtrList func_args;
  const auto inputs = fg->get_inputs();
  (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(func_args),
                       [](const AnfNodePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         // NOTE(review): arg->abstract() is not null-checked before Broaden — confirm loaded inputs always carry abstracts.
                         return arg->abstract()->Broaden();
                       });
  if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
    MS_LOG(EXCEPTION) << "The input arguments is not compatible with the function graph which has been exported before."
                      << " Please check the args is same with export.\n"
                      << "Export input args info:" << abstract::ArgsToString(func_args) << "\n"
                      << "The input args info:" << abstract::ArgsToString(broaded_args);
  }

  // suppose that there is not KeywordArgument for the top graph
  // get the hyper parameter
  for (const auto &param : fg->parameters()) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      auto value = param_node->default_param();
      MS_EXCEPTION_IF_NULL(value);
      // NOTE(review): cast result is not null-checked; a non-tensor default would yield a null abs_value — confirm.
      auto abs_value = value->ToAbstract()->cast<abstract::AbstractTensorPtr>();
      auto ref_key = std::make_shared<RefKey>(param_node->name());
      auto abs_ref_key = ref_key->ToAbstract();
      auto abs_ref = std::make_shared<abstract::AbstractRef>(abs_ref_key, abs_value);
      broaded_args.push_back(abs_ref);
    }
  }
  // Re-run abstract analysis, copy the inferred abstracts from the analysis
  // result cache back onto the corresponding nodes, then drop the cache.
  (void)AbstractAnalyze(res, res->func_graph(), broaded_args, true);
  auto it = abstract::AnalysisResultCacheMgr::GetInstance().begin();
  auto it_end = abstract::AnalysisResultCacheMgr::GetInstance().end();
  for (; it != it_end; ++it) {
    it->first->node()->set_abstract(it->second->abstract());
  }
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  return true;
}
1021
ActionPyStub(const ResourcePtr & res,opt::python_pass::Phase phase)1022 bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
1023 MS_EXCEPTION_IF_NULL(res->manager());
1024 MS_EXCEPTION_IF_NULL(res->func_graph());
1025 auto ppm = opt::python_pass::PyPassManager::GetInstance();
1026 ppm->SetResource(res);
1027 return ppm->GetPassGroup(phase)->Run(res->func_graph());
1028 }
1029
PreAdActionPyStub(const ResourcePtr & res)1030 bool PreAdActionPyStub(const ResourcePtr &res) {
1031 if (!ActionPyStub(res, opt::python_pass::Phase::PREAD)) {
1032 MS_LOG(DEBUG) << "No Match.";
1033 }
1034 return true;
1035 }
1036
OptActionVmPyStub(const ResourcePtr & res)1037 bool OptActionVmPyStub(const ResourcePtr &res) {
1038 if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
1039 if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
1040 // Renomalize
1041 FuncGraphPtr func_graph = res->func_graph();
1042 MS_EXCEPTION_IF_NULL(func_graph);
1043 abstract::AbstractBasePtrList args_spec;
1044 auto parameters = func_graph->parameters();
1045 (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
1046 [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
1047 FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
1048 res->set_func_graph(new_fg);
1049 res->set_args_spec(args_spec);
1050 }
1051 if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
1052 return VmOptimizeAction(res);
1053 }
1054 }
1055 return true;
1056 }
1057
OptActionGePyStub(const ResourcePtr & res)1058 bool OptActionGePyStub(const ResourcePtr &res) {
1059 if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
1060 if (opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
1061 // Renomalize
1062 FuncGraphPtr func_graph = res->func_graph();
1063 MS_EXCEPTION_IF_NULL(func_graph);
1064 abstract::AbstractBasePtrList args_spec;
1065 auto parameters = func_graph->parameters();
1066 (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
1067 [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
1068 FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
1069 res->set_func_graph(new_fg);
1070 res->set_args_spec(args_spec);
1071 }
1072 if (opt::python_pass::PyPassManager::GetInstance()->ShouldReOpt()) {
1073 return GeOptimizeAction(res);
1074 }
1075 }
1076 return true;
1077 }
1078
CommonPipeline()1079 static std::vector<ActionItem> CommonPipeline() {
1080 std::vector<ActionItem> actions;
1081
1082 // Parse the python ast to ANF graph
1083 (void)actions.emplace_back(std::make_pair("parse", ParseAction));
1084
1085 // Resolve the python func
1086 (void)actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction));
1087
1088 auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs();
1089 if (!multi_graphs) {
1090 (void)actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
1091 }
1092
1093 (void)actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
1094 // Evaluate type and shape, and specialize
1095 (void)actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
1096 // Auto-monad for side-effects handling.
1097 (void)actions.emplace_back(std::make_pair("auto_monad", AutoMonadAction));
1098 // Do data structure simplifications and inline
1099 (void)actions.emplace_back(std::make_pair("inline", OptInlineAction));
1100 // Add pre-ad, post-inline python pass stub
1101 (void)actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));
1102 // Do PipelineSplit
1103 (void)actions.emplace_back(std::make_pair("pipeline_split", PipelineSplitAction));
1104
1105 return actions;
1106 }
1107
GePipeline()1108 std::vector<ActionItem> GePipeline() {
1109 auto actions = CommonPipeline();
1110 // optimize
1111 (void)actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
1112 // Add opt-stage python pass stub
1113 (void)actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
1114 (void)actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
1115 (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
1116 (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1117 return actions;
1118 }
1119
VmPipeline()1120 std::vector<ActionItem> VmPipeline() {
1121 auto actions = CommonPipeline();
1122
1123 // optimize
1124 (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
1125
1126 // Add opt-stage python pass stub
1127 (void)actions.emplace_back(std::make_pair("py_opt", OptActionVmPyStub));
1128
1129 (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
1130
1131 // eliminate forward cnode for grad graph
1132 (void)actions.emplace_back(std::make_pair("eliminate_forward_cnode", EliminateForwardCNode));
1133
1134 (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1135 #if ((defined ENABLE_CPU) && (!defined _WIN32))
1136 if (ps::PSContext::instance()->is_worker()) {
1137 std::string server_mode = ps::PSContext::instance()->server_mode();
1138 if (server_mode == ps::kServerModeFL || server_mode == ps::kServerModeHybrid) {
1139 (void)actions.emplace_back(std::make_pair("worker", StartFLWorkerAction));
1140 } else {
1141 (void)actions.emplace_back(std::make_pair("worker", StartPSWorkerAction));
1142 }
1143 }
1144 #endif
1145 // compile the ANF graph
1146 (void)actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
1147
1148 // to execute the graph
1149 (void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
1150
1151 return actions;
1152 }
1153
BackendPipeline()1154 std::vector<ActionItem> BackendPipeline() {
1155 std::vector<ActionItem> actions;
1156 // compile the ANF graph
1157 (void)actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
1158 // to execute the graph
1159 (void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
1160 return actions;
1161 }
MindIRPipeline()1162 std::vector<ActionItem> MindIRPipeline() {
1163 std::vector<ActionItem> actions;
1164 // Set funcGraph loaded from MindIR to resource.
1165 (void)actions.emplace_back(std::make_pair("load_mindir", SetMindIRGraphAction));
1166 (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1167 // compile the ANF graph
1168 (void)actions.emplace_back(std::make_pair("task_emit", TaskEmitAction));
1169 // to execute the graph
1170 (void)actions.emplace_back(std::make_pair("execute", ExecuteAction));
1171 return actions;
1172 }
1173 #if ((defined ENABLE_CPU) && (!defined _WIN32))
ServerPipeline()1174 std::vector<ActionItem> ServerPipeline() {
1175 auto actions = CommonPipeline();
1176 (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
1177 (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1178 (void)actions.emplace_back(std::make_pair("server", StartServerAction));
1179 return actions;
1180 }
1181
PServerPipeline()1182 std::vector<ActionItem> PServerPipeline() {
1183 auto actions = CommonPipeline();
1184 (void)actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
1185 (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
1186 (void)actions.emplace_back(std::make_pair("validate", ValidateAction));
1187 (void)actions.emplace_back(std::make_pair("pserver", StartPSServerAction));
1188 return actions;
1189 }
1190
PSchedulerPipeline()1191 std::vector<ActionItem> PSchedulerPipeline() {
1192 std::vector<ActionItem> actions;
1193 (void)actions.emplace_back(std::make_pair("scheduler", StartPSSchedulerAction));
1194 return actions;
1195 }
1196 #endif
1197 } // namespace pipeline
1198 } // namespace mindspore
1199