/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/ps/action.h"

#include <memory>
#include <map>
#include <utility>
#include <vector>
#include <set>
#include <string>
#include <algorithm>
#include <functional>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#include "ir/param_info.h"
#include "ir/cell.h"
#include "include/common/utils/python_adapter.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/parallel_context.h"
#include "abstract/abstract_value.h"
#include "frontend/operator/composite/composite.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/graph_util/graph_splitter.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/shard/shard.h"
#include "pipeline/jit/ps/pipeline.h"
#include "pipeline/jit/ps/pass.h"
#include "pipeline/jit/ps/parse/parse_base.h"
#include "pipeline/jit/ps/parse/data_converter.h"
#include "pipeline/jit/ps/static_analysis/auto_monad.h"
#include "pipeline/jit/ps/static_analysis/order_enforce.h"
#include "pipeline/jit/ps/static_analysis/static_analysis.h"
#include "pipeline/jit/ps/static_analysis/async_eval_result.h"
#include "pipeline/jit/ps/static_analysis/program_specialize.h"
#include "pipeline/jit/ps/resource.h"
#include "pipeline/jit/ps/remove_value_node_dup.h"
#include "pipeline/jit/ps/event_message_print.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/ad/grad.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "utils/phase.h"
#include "utils/compile_config.h"
#include "backend/graph_compiler/transform.h"
#include "load_mindir/infer_mindir.h"
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "include/backend/debug/profiler/profiling.h"
#include "frontend/optimizer/fallback_rewriter.h"
#include "pipeline/jit/ps/load_mindir.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/cluster/cluster_context.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "include/backend/distributed/ps/util.h"
#endif

namespace mindspore {
namespace pipeline {
namespace {
const auto kFirstInput = 1;
const auto kSecondInput = 2;
const auto kLazyInlineThershold = 64;

bool ExistControlFlow(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  return !func_graph->func_graphs_used_total().empty();
}

bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
         abs->BuildType()->isa<Number>();
}

bool EnableSequenceBroaden(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  return abs->isa<abstract::AbstractSequence>() &&
         abs->cast<abstract::AbstractSequencePtr>()->ContainsAllBroadenTensors();
}

bool ContainsAbstractFunction(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractFunction>()) {
    return true;
  }
  if (abs->isa<abstract::AbstractSequence>()) {
    const auto &abs_list = abs->cast<abstract::AbstractSequencePtr>()->elements();
    return std::any_of(abs_list.cbegin(), abs_list.cend(),
                       [](const auto &elem) { return ContainsAbstractFunction(elem); });
  }
  if (abs->isa<abstract::AbstractDictionary>()) {
    const auto &abs_pair_list = abs->cast<abstract::AbstractDictionaryPtr>()->elements();
    return std::any_of(abs_pair_list.cbegin(), abs_pair_list.cend(),
                       [](const auto &pair) { return ContainsAbstractFunction(pair.second); });
  }
  return false;
}
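
// Illustrative sketch (not part of the original source): ContainsAbstractFunction
// recurses through sequence and dictionary abstracts, so an argument that is,
// e.g., a tuple holding a function keeps its parameter rather than being pruned.
// Assuming tensor_abs and func_abs are valid abstracts built elsewhere:
//
//   abstract::AbstractBasePtrList elements = {tensor_abs, func_abs};
//   auto tuple_abs = std::make_shared<abstract::AbstractTuple>(elements);
//   ContainsAbstractFunction(tuple_abs);  // true: a nested AbstractFunction is found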

void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph, const std::vector<ValuePtr> &arguments) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> new_paras;
  for (size_t i = 0; i < func_graph->parameters().size(); ++i) {
    const auto &param = func_graph->parameters()[i];
    auto param_node = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      new_paras.push_back(param_node);
      continue;
    }

    // Handle the Parameter from input arguments.
    if (i < arguments.size()) {
      auto param_value = dyn_cast<tensor::MetaTensor>(arguments[i]);
      if (param_value != nullptr && param_value->is_parameter()) {
        param_node->set_default_param(param_value);
      }
    }

    AbstractBasePtr param_abs = param_node->abstract();
    MS_EXCEPTION_IF_NULL(param_abs);
    if ((param_abs->BuildValue() == kValueAny && !ContainsAbstractFunction(param_abs)) ||
        EnableGradForScalar(param_abs) || EnableSequenceBroaden(param_abs)) {
      new_paras.push_back(param_node);
    } else {
      MS_LOG(INFO) << "Remove the " << i << "th parameter, since it's passed a constant argument.";
    }
  }
  func_graph->set_parameters(new_paras);
}
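
// Worked example of the pruning above (illustrative, with hypothetical inputs):
// compiling a call like net(x, 3) gives the second parameter an abstract whose
// BuildValue() is the concrete scalar 3 rather than kValueAny, so that parameter
// is dropped; the tensor input x stays variable (BuildValue() == kValueAny) and
// is kept.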

// Check whether ops such as ScalarAdd and ScalarSub exist; they will back off to the CPU.
bool IsNeedBackoffGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return(), SuccDeeperSimple);
  return std::any_of(node_list.begin(), node_list.end(),
                     [](const AnfNodePtr &node) { return common::AnfAlgo::IsNodeMutableScalar(node); });
}

// Disable mindRT in the combined heterogeneous and dynamic-shape scenario.
void DisableMindRT(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    return;
  }
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->cache_enable()) {
    return;
  }
#endif
}

void TaskEmitActionForMindRT(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  // Get the mindRT backend.
  auto bc_ptr = resource->GetBackend();
  // In the pyexecute kernel, the input data is stored as user data holding a Python object; this converter
  // turns that user data into a device pointer in the device address.
  compile::set_pydata_converter([](const py::object &obj, ValuePtr *value) { return parse::ConvertData(obj, value); });
  auto mindrt_bc_ptr = std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr);
  MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
  MS_EXCEPTION_IF_NULL(resource->func_graph());
  auto actor_info = mindrt_bc_ptr->CompileGraphs(resource->func_graph());
  resource->SetResult(kOutput, actor_info);
  resource->SetResult(kActorInfo, actor_info);
}

void ExecuteActionForMindRT(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  const auto actor_info = resource->GetResult(kOutput).cast<compile::ActorInfo>();
  // Get the mindRT backend.
  auto bc_ptr = resource->GetBackend();
  auto mindrt_bc_ptr = (std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr)).get();
  MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);

  // Construct the graph run function ptr.
  compile::VmEvalFuncPtr run =
    std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, actor_info](const VectorRef &args) -> BaseRef {
      MS_LOG(DEBUG) << "Execute args size " << args.size();
      VectorRef outputs;
      mindrt_bc_ptr->RunGraph(actor_info, args, &outputs);
      MS_LOG(DEBUG) << "out size " << outputs.size();
      if (outputs.empty()) {
        return VectorRef();
      } else {
        return outputs[0];
      }
    });
  resource->SetResult(kOutput, run);
}

FuncGraphPtr ConstructGraphForEval(const ValuePtr &func, const abstract::AbstractBasePtrList &args_abs) {
  auto func_abs = func->ToAbstract();
  if (!func_abs->isa<abstract::AbstractFunction>()) {
    MS_LOG(EXCEPTION) << "The value: " << func->ToString() << " is not a callable object.";
  }
  // Construct a function graph.
  auto infer_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> inputs = {std::make_shared<ValueNode>(func)};
  std::transform(args_abs.begin(), args_abs.end(), std::back_inserter(inputs),
                 [infer_graph](const AbstractBasePtr &) -> AnfNodePtr { return infer_graph->add_parameter(); });
  auto infer_node = infer_graph->NewCNode(inputs);
  infer_graph->set_return(infer_node);
  return infer_graph;
}
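
// Shape of the generated wrapper graph (illustrative): for a callable f and two
// abstract arguments, ConstructGraphForEval builds the single-call graph
//
//   infer_graph(p0, p1) { return f(p0, p1); }
//
// so the analysis engine can infer f's output abstract by evaluating that call.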
}  // namespace
using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;

// Whether this process is in a MindSpore cluster.
static bool is_cluster_initialized = false;

bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return(), SuccDeeperSimple);
  return std::any_of(node_list.begin(), node_list.end(), [](const AnfNodePtr &node) {
    if (common::AnfAlgo::IsCallNode(node)) {
      return false;
    }
    return common::AnfAlgo::IsDynamicShape(node);
  });
}

abstract::AnalysisResult AbstractAnalyze(const abstract::AnalysisEnginePtr &engine, const FuncGraphPtr &func_graph,
                                         const abstract::AbstractBasePtrList &args_abs, bool is_load_resource,
                                         bool clear) {
  MS_LOG(DEBUG) << "AbstractAnalyze start";
  py::gil_scoped_acquire gil;
  MS_EXCEPTION_IF_NULL(engine);
  if (clear || is_load_resource) {
    auto manager = engine->func_graph_manager();
    MS_EXCEPTION_IF_NULL(manager);
    engine->Clear();
    static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
    for (auto &node : manager->all_nodes()) {
      MS_EXCEPTION_IF_NULL(node);
      // Handle the previously inferred value for a CNode if it is loaded from MindIR.
      // If the primitive is not defined in the front end, keep the inferred value loaded from MindIR.
      if (is_load_resource) {
        auto primitive = GetCNodePrimitive(node);
        if (primitive != nullptr) {
          auto is_load = primitive->GetAttr("is_load");
          if (abstract::GetPrimEvaluator(primitive, engine) == nullptr && is_load != nullptr &&
              GetValue<bool>(is_load)) {
            MS_LOG(INFO) << "The primitive is not defined in front end. Primitive: " << primitive->ToString();
            continue;
          }
        }
        if (!clear && node->isa<Parameter>()) {
          continue;
        }
      }

      const AbstractBasePtr &prev_inferred = node->abstract();
      // Keep the previously inferred value for a ValueNode if the inferred value is not an AbstractFunction.
      if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
        // Reset tuple/list abstract use flags.
        if (enable_eliminate_unused_element && prev_inferred != nullptr &&
            prev_inferred->isa<abstract::AbstractSequence>()) {
          SetSequenceNodeElementsUseFlags(node, nullptr);
        }
        node->set_abstract(nullptr);
        MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
      }
    }
  }
  auto res = engine->Run(func_graph, args_abs);
  MS_LOG(INFO) << "function call depth: " << abstract::FunctionCallDepth()
               << ", simulate call depth: " << abstract::StackFrameDepth();
  MS_LOG(DEBUG) << "AbstractAnalyze end";
  return res;
}

abstract::AnalysisResult AbstractAnalyze(const ValuePtr &func, const abstract::AbstractBasePtrList &args_abs,
                                         bool clear) {
  auto infer_graph = func->isa<FuncGraph>() ? func->cast<FuncGraphPtr>() : ConstructGraphForEval(func, args_abs);
  auto manager = Manage(infer_graph, true);
  auto engine = std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager);
  return AbstractAnalyze(engine, infer_graph, args_abs, false, clear);
}

abstract::AnalysisResult AbstractAnalyzeWithResourceClean(const ValuePtr &func,
                                                          const abstract::AbstractBasePtrList &args_abs) {
  auto infer_graph = func->isa<FuncGraph>() ? func->cast<FuncGraphPtr>() : ConstructGraphForEval(func, args_abs);

  ResourcePtr resource = std::make_shared<Resource>();
  resource->set_func_graph(infer_graph);

  auto engine = resource->engine();
  auto res = AbstractAnalyze(engine, infer_graph, args_abs, false, true);

  GraphExecutorPy::GetInstance()->CleanCompileRes(resource);
  return res;
}

FuncGraphPtr ProgramSpecialize(const abstract::AnalysisEnginePtr &engine, const FuncGraphPtr &func_graph,
                               const abstract::AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(engine);
  MS_LOG(DEBUG) << "ProgramSpecialize start";
  abstract::ProgramSpecializer specializer(engine);
  FuncGraphPtr result = specializer.Run(func_graph, context);
  auto manager = engine->func_graph_manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->KeepRoots({result});
  specializer.SpecializeCNodeInput0FuncGraph();
  MS_LOG(DEBUG) << "ProgramSpecialize end";
  return result;
}

FuncGraphPtr Renormalize(const ResourcePtr &resource, const FuncGraphPtr &func_graph,
                         const abstract::AbstractBasePtrList &args_abs) {
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Renormalize start";
  auto engine = resource->engine();

  abstract::AnalysisResult result;
  {
    MsProfileStatGuard stat_guard("renormalize.infer");
    result = AbstractAnalyze(engine, func_graph, args_abs, resource->is_load(), true);
  }
  FuncGraphPtr res;
  {
    MsProfileStatGuard stat_guard("renormalize.specialize");
    res = ProgramSpecialize(engine, func_graph, result.context);
    resource->set_func_graph(res);
  }

  MS_LOG(DEBUG) << "Renormalize end";
  return res;
}
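
// The renormalize flow above is two-phase; a minimal sketch, assuming fg is a
// managed graph whose abstracts were invalidated by an earlier rewrite:
//
//   auto result = AbstractAnalyze(engine, fg, args_abs, resource->is_load(), true);  // re-infer
//   auto new_fg = ProgramSpecialize(engine, fg, result.context);                     // re-specialize
//
// Passes that change node types or signatures rely on such a renormalize so that
// downstream passes see consistent abstracts.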

FuncGraphPtr Renormalize(const ValuePtr &func, const abstract::AbstractBasePtrList &args_abs) {
  auto func_abs = func->ToAbstract();
  if (!func_abs->isa<abstract::AbstractFunction>()) {
    MS_LOG(EXCEPTION) << "The value: " << func->ToString() << " is not a callable object.";
  }
  auto func_graph = ConstructGraphForEval(func, args_abs);
  auto manager = Manage(func_graph, true);
  auto engine = std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager);

  abstract::AnalysisResult result;
  {
    MsProfileStatGuard stat_guard("renormalize.infer");
    result = AbstractAnalyze(engine, func_graph, args_abs, false);
  }
  FuncGraphPtr res;
  {
    MsProfileStatGuard stat_guard("renormalize.specialize");
    res = ProgramSpecialize(engine, func_graph, result.context);
  }

  return res;
}

void SetMindIRLoadFlag(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto all_graphs = manager->func_graphs();
  for (auto &graph : all_graphs) {
    MS_EXCEPTION_IF_NULL(graph);
    if (graph->has_attr("is_load")) {
      resource->set_is_load(true);
      return;
    }
  }
}

namespace {
// Get the entry function or class.method name.
std::string GetFunctionName(const py::object &input) {
  // Get the Cell.construct() or @jit function name.
  std::string function_name;
  if (py::hasattr(input, parse::PYTHON_PARSE_METHOD)) {
    // The class type string format is like: <class 'x.x.xxx'>
    std::string class_type_name = py::cast<std::string>(py::str(input.get_type()));
    constexpr auto class_type_prefix_len = 8;  // <class '
    constexpr auto class_type_suffix_len = 2;  // '>
    const auto class_type_len = class_type_name.length();
    // Exclude the class prefix and suffix.
    auto class_name =
      class_type_name.substr(class_type_prefix_len, class_type_len - class_type_prefix_len - class_type_suffix_len);
    auto method_name = py::cast<std::string>(input.attr(parse::PYTHON_PARSE_METHOD));
    function_name = class_name + '.' + method_name;
  } else if (py::hasattr(input, "__jit_function__") && py::hasattr(input, "__name__")) {
    // Get the @jit decorated function name.
    auto jit_name = py::cast<std::string>(input.attr("__name__"));
    function_name = jit_name;
  } else {
    MS_EXCEPTION(NotSupportError) << "Entry Python object for JIT is invalid.\ninput: " << py::str(input);
  }
  MS_LOG(DEBUG) << "function_name: " << function_name;
  return function_name;
}
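
// Worked example of the name extraction above (hypothetical class names): for a
// Cell subclass, py::str(input.get_type()) yields a string such as
//
//   <class 'my_pkg.my_net.MyNet'>
//
// Stripping the 8-character prefix "<class '" and the 2-character suffix "'>"
// leaves "my_pkg.my_net.MyNet"; with construct as the parse method, the result
// is "my_pkg.my_net.MyNet.construct".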

// Update top graph name.
void UpdateTopGraphDebugInfo(const FuncGraphPtr &func_graph, const py::object &input) {
  auto function_name = GetFunctionName(input);
  // Normalize the name.
  std::replace(function_name.begin(), function_name.end(), '.', '_');
  std::replace(function_name.begin(), function_name.end(), '<', '_');
  std::replace(function_name.begin(), function_name.end(), '>', '_');

  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  func_graph->debug_info()->set_name(function_name);
}

struct FuncArgSpec {
  AnfNodePtrList args_;
  ParameterPtr varargs_{nullptr};
  AnfNodePtrList kwonlyargs_;
  ParameterPtr varkw_{nullptr};
};

void MakeDefaultValue(const py::dict &defaults, const std::string &arg_name,
                      std::vector<std::string> *namelist_for_default_value, std::vector<AnfNodePtr> *default_values) {
  (void)namelist_for_default_value->emplace_back(arg_name);
  if (defaults.contains(arg_name)) {
    AnfNodePtr arg_node = NewValueNode(parse::data_converter::PyDataToValue(defaults[py::str(arg_name)]));
    (void)default_values->emplace_back(arg_node);
  } else {
    (void)default_values->emplace_back(NewValueNode(kNull));
  }
}

bool CheckIgnoreSelfParam(const py::object &input) {
  auto input_type = parse::data_converter::GetObjType(input);
  if (input_type == parse::ResolveType::RESOLVE_TYPE_CLASS_INSTANCE) {
    return true;
  }
  if (input_type == parse::ResolveType::RESOLVE_TYPE_METHOD) {
    py::object method_object = python_adapter::GetPyObjAttr(input, parse::PYTHON_GET_METHOD_SELF_CLASS);
    if (!py::isinstance<py::none>(method_object)) {
      return true;
    }
  }
  return false;
}

FuncArgSpec GetFuncArgSpec(const FuncGraphPtr &func_graph, const py::object &input) {
  auto func = input;
  if (py::hasattr(input, parse::PYTHON_PARSE_METHOD)) {
    auto func_name = py::cast<std::string>(input.attr(parse::PYTHON_PARSE_METHOD));
    func = input.attr(func_name.c_str());
  }
  py::tuple obj_tuple =
    python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, "get_arg_spec_and_default_values", func);
  auto full_arg_spec = obj_tuple[0];
  py::dict defaults = obj_tuple[1];
  std::vector<std::string> namelist_for_default_value;
  std::vector<AnfNodePtr> default_values;
  FuncArgSpec arg_spec;
  bool ignore_self_param = CheckIgnoreSelfParam(input);
  if (py::hasattr(full_arg_spec, "args")) {
    for (const auto &arg : full_arg_spec.attr("args")) {
      auto arg_name = py::cast<std::string>(arg);
      if (arg_name == "self" && ignore_self_param) {
        continue;
      }
      auto para = func_graph->add_parameter();
      para->set_is_top_graph_param(true);
      para->set_name(arg_name);
      (void)arg_spec.args_.emplace_back(para);
      MakeDefaultValue(defaults, arg_name, &namelist_for_default_value, &default_values);
    }
  }

  if (py::hasattr(full_arg_spec, "varargs")) {
    auto varargs = full_arg_spec.attr("varargs");
    if (!py::isinstance<py::none>(varargs)) {
      arg_spec.varargs_ = func_graph->add_parameter();
      arg_spec.varargs_->set_is_top_graph_param(true);
      auto arg_name = py::cast<std::string>(varargs);
      arg_spec.varargs_->set_name(arg_name);
      func_graph->set_has_vararg(true);
      MakeDefaultValue(defaults, arg_name, &namelist_for_default_value, &default_values);
    }
  }

  if (py::hasattr(full_arg_spec, "kwonlyargs")) {
    for (const auto &arg : full_arg_spec.attr("kwonlyargs")) {
      auto para = func_graph->add_parameter();
      para->set_is_top_graph_param(true);
      auto arg_name = py::cast<std::string>(arg);
      para->set_name(arg_name);
      (void)arg_spec.kwonlyargs_.emplace_back(para);
      MakeDefaultValue(defaults, arg_name, &namelist_for_default_value, &default_values);
    }
    func_graph->set_kwonlyargs_count(SizeToInt(arg_spec.kwonlyargs_.size()));
  }

  if (py::hasattr(full_arg_spec, "varkw")) {
    auto varkw = full_arg_spec.attr("varkw");
    if (!py::isinstance<py::none>(varkw)) {
      arg_spec.varkw_ = func_graph->add_parameter();
      arg_spec.varkw_->set_is_top_graph_param(true);
      auto arg_name = py::cast<std::string>(varkw);
      arg_spec.varkw_->set_name(arg_name);
      func_graph->set_has_kwarg(true);
      MakeDefaultValue(defaults, arg_name, &namelist_for_default_value, &default_values);
    }
  }
  func_graph->SetDefaultValues(namelist_for_default_value, default_values);
  return arg_spec;
}

void BuildTopGraph(const FuncGraphPtr &func_graph, const py::object &input,
                   const abstract::AbstractBasePtrList &args_abs) {
  // Make a Resolve node for the user's top graph 'input'.
  auto function_name = GetFunctionName(input);
  parse::NameSpacePtr name_space =
    std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_ENTRY, py::str(function_name), input);
  parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(function_name);
  MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
  ValueNodePtr module_node = NewValueNode(name_space);
  ValueNodePtr symbol_node = NewValueNode(symbol);

  bool contains_value_any = false;
  ValuePtrList args_value_list;
  (void)std::transform(args_abs.cbegin(), args_abs.cend(), std::back_inserter(args_value_list),
                       [&contains_value_any](const AbstractBasePtr &abs) {
                         auto res = abs->BuildValue();
                         if (res->isa<ValueAny>()) {
                           contains_value_any = true;
                         }
                         return res;
                       });
  CNodePtr resolve_node;
  if (contains_value_any) {
    resolve_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
  } else {
    ValueNodePtr args_node = NewValueNode<ValuePtrList>(args_value_list);
    resolve_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node, args_node});
  }

  auto arg_spec = GetFuncArgSpec(func_graph, input);
  bool need_unpack = false;
  if (func_graph->has_vararg() || func_graph->has_kwarg() || func_graph->kwonlyargs_count() > 0) {
    need_unpack = true;
  }
  // Call the user's top graph inside this top graph.
  AnfNodePtrList inputs;
  if (!need_unpack) {
    (void)inputs.emplace_back(resolve_node);
    std::copy(func_graph->parameters().cbegin(), func_graph->parameters().cend(), std::back_inserter(inputs));
  } else {
    (void)inputs.emplace_back(NewValueNode(std::make_shared<prim::UnpackCall>(parse::NAMED_METAGRAPH_UNPACKCALL)));
    (void)inputs.emplace_back(resolve_node);
    if (!arg_spec.args_.empty()) {
      AnfNodePtrList args_inputs = {NewValueNode(prim::kPrimMakeTuple)};
      std::copy(arg_spec.args_.cbegin(), arg_spec.args_.cend(), std::back_inserter(args_inputs));
      (void)inputs.emplace_back(func_graph->NewCNodeInOrder(args_inputs));
    }
    if (arg_spec.varargs_ != nullptr) {
      (void)inputs.emplace_back(arg_spec.varargs_);
    }
    if (arg_spec.varkw_ != nullptr) {
      (void)inputs.emplace_back(arg_spec.varkw_);
    }
    if (!arg_spec.kwonlyargs_.empty()) {
      AnfNodePtrList key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
      AnfNodePtrList value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
      for (const auto &kwonlyarg : arg_spec.kwonlyargs_) {
        (void)key_inputs.emplace_back(NewValueNode(kwonlyarg->cast<ParameterPtr>()->name()));
        (void)value_inputs.emplace_back(kwonlyarg);
      }
      auto make_dict =
        func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNodeInOrder(key_inputs),
                                     func_graph->NewCNodeInOrder(value_inputs)});
      (void)inputs.emplace_back(make_dict);
    }
  }
  auto output = func_graph->NewCNodeInOrder(inputs);
  constexpr auto recursive_level = 2;
  MS_LOG(DEBUG) << "output: " << output->DebugString(recursive_level);
  func_graph->set_output(output);
}
}  // namespace

bool BootstrapAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  TraceManager::OpenParserDebugInfoFlag();
  if (!resource->source_input()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Bootstrap error";
  }
  py::object input = resource->source_input();
  parse::Parser::InitParserEnvironment(input);
  parse::Parser::EnableDeferResolve(false);
  py::module path = py::module::import("os.path");
  auto dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();
  python_adapter::set_python_env_flag(true);
  python_adapter::SetPythonPath(dir);

  // Create a fake top graph first.
  auto top_graph = std::make_shared<FuncGraph>();
  MS_EXCEPTION_IF_NULL(top_graph);
  auto is_top_graph = (py::hasattr(input, parse::PYTHON_PARSE_METHOD) || py::hasattr(input, "__jit_function__"));
  if (!is_top_graph) {
    MS_EXCEPTION(NotSupportError) << "Not supported Python object for JIT entry.\ninput: " << py::str(input);
  }
  UpdateTopGraphDebugInfo(top_graph, input);
  // Call the user top graph with its arguments.
  BuildTopGraph(top_graph, input, resource->args_abs());
  // Set the top graph.
  parse::Parser::UpdateTopFuncGraph(top_graph);
  resource->set_func_graph(top_graph);
  FuncGraphManagerPtr manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(top_graph);
  return true;
}

bool ParseAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  TraceManager::OpenParserDebugInfoFlag();
  if (!resource->source_input()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parse error";
  }

  py::object input = resource->source_input();
  parse::Parser::InitParserEnvironment(input);
  parse::Parser::EnableDeferResolve(false);
  py::module path = py::module::import("os.path");
  auto dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();

  python_adapter::set_python_env_flag(true);
  python_adapter::SetPythonPath(dir);

  ValuePtrList args_value_list;
  (void)std::transform(resource->args_abs().begin(), resource->args_abs().end(), std::back_inserter(args_value_list),
                       [](const AbstractBasePtr &abs) { return abs->BuildValue(); });
  parse::DataConverter data_converter(args_value_list, true);
  auto converted_ret = data_converter.ConvertData(input);
  if (converted_ret == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(input));
  }

  auto top_graph = converted_ret->cast<FuncGraphPtr>();
  if (top_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell.";
  }
  if (py::hasattr(input, parse::PYTHON_PARSE_METHOD) || py::hasattr(input, "__jit_function__")) {
    (void)std::for_each(top_graph->parameters().begin(), top_graph->parameters().end(),
                        [](const AnfNodePtr &param) { param->cast<ParameterPtr>()->set_is_top_graph_param(true); });
  }
  parse::Parser::UpdateTopFuncGraph(top_graph);
  resource->set_func_graph(top_graph);
  FuncGraphManagerPtr manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(top_graph);

  parse::Parser::EnableDeferResolve(true);
  return true;
}

// The graphs in obj_map share the same structure, so they can be optimized into one graph.
// This step performs that optimization: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)} ->
// graph1(x){base_graph(x, fv1, fv2)}, graph2(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}.
// All graphs in obj_map share base_graph.
bool CombineLikeGraphs(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto &obj_map = parse::data_converter::GetObjGraphs();
  for (auto it = obj_map.rbegin(); it != obj_map.rend(); ++it) {
    if (it->first.find("lazy_inline") != it->first.npos) {
      continue;
    }
    auto &graphs = it->second;
    MS_LOG(DEBUG) << "Start combine like graph:" << it->first << ", size:" << graphs.size();
    auto fg = graphs[0];
    FuncGraphVector func_graphs = {fg};
    Cloner cloner(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
                  std::make_shared<TraceCombileLikeGraphs>());
    cloner.Run();
    auto cloned_fg_iter = cloner.cloned_func_graphs().find(fg);
    if (cloned_fg_iter == cloner.cloned_func_graphs().end()) {
      MS_LOG(INTERNAL_EXCEPTION) << "Clone func graph failed! " << fg->ToString();
    }
    auto base_graph = cloned_fg_iter->second;
    MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();

    if (fg->parameter_obj_nodes().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) ||
        fg->stage() != -1) {
      continue;
    }
    auto &cloned_nodes = cloner.cloned_nodes();
    for (auto &fv : fg->parameter_obj_nodes()) {
      TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fg->output()->debug_info()));
      auto param = base_graph->add_parameter();
      MS_EXCEPTION_IF_NULL(resource->manager());
      auto &node_users = resource->manager()->node_users()[fv];
      for (auto &n : node_users) {
        // If the user is not in this graph, there is no need to change it.
        auto iter = cloned_nodes.find(n.first);
        if (iter == cloned_nodes.end()) {
          continue;
        }
        auto repl_n = iter->second->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(repl_n);
        repl_n->set_input(IntToSize(n.second), param);
      }
    }
    MS_LOG(DEBUG) << "Fg0 parameter_obj_nodes size :" << fg->parameter_obj_nodes().size();

    for (auto &g : graphs) {
      TraceGuard guard(std::make_shared<TraceCopy>(fg->output()->debug_info()));
      auto &fvs = g->parameter_obj_nodes();
      std::vector<AnfNodePtr> new_node_inputs;
      new_node_inputs.push_back(NewValueNode(base_graph));
      for (auto &p : g->parameters()) {
        AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p);
        new_node_inputs.push_back(para_after_cast);
      }
      (void)new_node_inputs.insert(new_node_inputs.end(), fvs.cbegin(), fvs.cend());
      AnfNodePtr out = g->NewCNodeBefore(g->get_return(), new_node_inputs);
      g->set_output(out);
      const int recursive_level = 4;
      MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(recursive_level);
    }
    MS_LOG(DEBUG) << "End combine graph:" << it->first;
  }
  return true;
}
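
// Concrete picture of the rewrite above (illustrative): two instances of the same
// cell with different weights end up sharing one compiled body,
//
//   before: dense1(x) { ... uses fv w1, b1 ... }    dense2(x) { ... uses fv w2, b2 ... }
//   after:  dense1(x) { return base(x, w1, b1); }   dense2(x) { return base(x, w2, b2); }
//
// where the free variables are lifted into explicit parameters of base, so the
// shared body is compiled only once.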

namespace {
// Get all the trainable parameters of the reusable cell.
void GenerateTopGraphParams(const FuncGraphPtr &fg, std::vector<AnfNodePtr> *params,
                            const FuncGraphPtr &top_func_graph) {
  MS_LOG(DEBUG) << "enter GenerateTopGraphParams: " << fg->ToString();
  auto obj_value = fg->python_obj();
  MS_EXCEPTION_IF_NULL(obj_value);
  auto wrapper = dyn_cast_ptr<parse::PyObjectWrapper>(obj_value);
  MS_EXCEPTION_IF_NULL(wrapper);
  auto obj = wrapper->obj();
  auto trainable_parameters = py::getattr(obj, "parameters_and_names", py::none())();
  for (auto tr : trainable_parameters) {
    auto item = py::cast<py::tuple>(tr);
    auto value = item[1];
    auto par_name = item[0].cast<std::string>();
    auto parameter_name = py::getattr(value, "name", py::str(par_name)).cast<std::string>();
    auto exist_fv = top_func_graph->GetParameterByName(parameter_name);
    if (exist_fv) {
      params->push_back(exist_fv);
      MS_LOG(DEBUG) << "exist: " << parameter_name;
    } else {
      auto fv = top_func_graph->AddFvParameter(parameter_name, parse::GetParameterValue(value));
      auto context = parallel::ParallelContext::GetInstance();
      if (context != nullptr && fv->has_default()) {
        auto fv_abs = pipeline::GetDefaultValueAbstract(fv);
        context->ParallelParameterContextRestoreShape(top_func_graph, fv, fv_abs);
        fv->set_abstract(fv_abs);
      }
      MS_LOG(DEBUG) << "New: " << parameter_name;
      params->push_back(fv);
    }
  }
  MS_LOG(DEBUG) << "finish GenerateTopGraphParams: " << fg->ToString();
}

void UpdateCellFuncGraph(const FuncGraphPtr &func_graph, const FuncGraphPtr &reusing_graph,
                         const FuncGraphPtr &top_func_graph) {
  std::vector<AnfNodePtr> new_node_inputs;
  new_node_inputs.push_back(NewValueNode(reusing_graph));
  std::vector<AnfNodePtr> fvs;
  GenerateTopGraphParams(func_graph, &fvs, top_func_graph);
  (void)new_node_inputs.insert(new_node_inputs.end(), fvs.rbegin(), fvs.rend());
  auto params = func_graph->parameters();
  (void)new_node_inputs.insert(new_node_inputs.end(), params.begin(), params.end());
  AnfNodePtr out = func_graph->NewCNodeInOrder(new_node_inputs);
  out->set_abstract(func_graph->output()->abstract());
  func_graph->set_output(out);
}

void GeneralizeReusingGraph(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_func_graph) {
  FuncGraphPtr fg = func_graph;
  FuncGraphVector func_graphs = {fg};
  Cloner cloner(func_graphs, false, false, true, std::make_shared<TraceCopy>(), std::make_shared<TraceGraphReusing>());
  cloner.Run();
  auto cloned_fg_iter = cloner.cloned_func_graphs().find(fg);
  if (cloned_fg_iter == cloner.cloned_func_graphs().end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Clone func graph failed! " << fg->ToString();
  }
  auto reusing_graph = cloned_fg_iter->second;
  auto &cloned_nodes = cloner.cloned_nodes();
  auto manager = fg->manager();
  std::vector<AnfNodePtr> fv_params;
  GenerateTopGraphParams(fg, &fv_params, top_func_graph);
  for (auto &fv : fv_params) {
    auto param = reusing_graph->InsertFrontParameter();
    const auto &top_param = fv->cast<ParameterPtr>();
    std::string name = "CR_" + top_param->name();
    param->debug_info()->set_name(name);
    param->set_name(name);
    param->set_abstract(top_param->abstract());
    auto &node_users = manager->node_users()[fv];
    for (auto &n : node_users) {
      auto iter = cloned_nodes.find(n.first);
      if (iter == cloned_nodes.end()) {
        continue;
      }
      auto repl_n = iter->second->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(repl_n);
      repl_n->set_input(IntToSize(n.second), param);
    }
  }

  if (func_graph->has_attr(FUNC_GRAPH_FLAG_NO_INLINE)) {
    reusing_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, func_graph->has_flag(FUNC_GRAPH_FLAG_NO_INLINE));
  } else {
    reusing_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
    reusing_graph->set_flag(FUNC_GRAPH_FLAG_CELL_REUSE, true);
  }

  // Update call nodes.
  auto no_inline_flag = reusing_graph->has_flag(FUNC_GRAPH_FLAG_NO_INLINE);
  auto cnodes_index = fg->func_graph_cnodes_index();
  for (auto &cnode_index : cnodes_index) {
    MS_EXCEPTION_IF_NULL(cnode_index.first);
    auto old_cnode = cnode_index.first->first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(old_cnode);
    auto cell_func_graph = old_cnode->func_graph();
    MS_EXCEPTION_IF_NULL(cell_func_graph);
    UpdateCellFuncGraph(cell_func_graph, reusing_graph, top_func_graph);

    // Optimize FuncGraph::scope() performance.
    cell_func_graph->set_flag(FUNC_GRAPH_FLAG_NO_CHILD_GRAPH, no_inline_flag);
  }
}

void SetCalledSubGraphMixedPrecisionFlag(const FuncGraphPtr &func_graph) {
  FuncGraphPtr fp16_mixed_precision_fg;
  FuncGraphPtr fp32_mixed_precision_fg;
  FuncGraphPtr bf16_mixed_precision_fg;
  // Find the first subgraph which has a mixed precision flag.
  for (auto &item : func_graph->func_graphs_used()) {
    if (item.first->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
      fp16_mixed_precision_fg = item.first;
    }
    if (item.first->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
      fp32_mixed_precision_fg = item.first;
    }
    if (item.first->has_flag(GRAPH_FLAG_MIX_PRECISION_BF16)) {
      bf16_mixed_precision_fg = item.first;
    }
    if ((fp32_mixed_precision_fg != nullptr) || (fp16_mixed_precision_fg != nullptr) ||
        (bf16_mixed_precision_fg != nullptr)) {
      break;
    }
  }

  // Propagate the mixed precision flag to every subgraph called from the flagged subgraph.
  if (fp16_mixed_precision_fg != nullptr) {
    for (auto sub_fg : fp16_mixed_precision_fg->func_graphs_used_total()) {
      sub_fg->set_flag(GRAPH_FLAG_MIX_PRECISION_FP16, true);
    }
  }
  if (fp32_mixed_precision_fg != nullptr) {
    for (auto sub_fg : fp32_mixed_precision_fg->func_graphs_used_total()) {
      sub_fg->set_flag(GRAPH_FLAG_MIX_PRECISION_FP32, true);
    }
  }
  if (bf16_mixed_precision_fg != nullptr) {
    for (auto sub_fg : bf16_mixed_precision_fg->func_graphs_used_total()) {
      sub_fg->set_flag(GRAPH_FLAG_MIX_PRECISION_BF16, true);
    }
  }
}
}  // namespace

// Turn the reusable cell into a reusable function graph.
bool GraphReusingAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  bool cell_reused = false;
  auto func_graph = resource->func_graph();
  std::multimap<int, FuncGraphPtr> order_fgs;
  for (auto &fg : func_graph->func_graphs_used_total()) {
    auto order_value = fg->get_attr(FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER);
    if (order_value == nullptr) {
      continue;
    }
    fg->erase_flag(FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER);
    order_fgs.insert(std::make_pair(GetValue<int>(order_value), fg));
  }
  for (auto it = order_fgs.rbegin(); it != order_fgs.rend(); ++it) {
    MS_LOG(INFO) << "Lazy_inline graph: " << it->second->ToString() << ", order: " << it->first;
    GeneralizeReusingGraph(it->second, func_graph);
    cell_reused = true;
  }
  if (!cell_reused) {
    return true;
  }

  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  const bool enable_ge = context->backend_policy() == "ge";
  const bool force_no_inline = common::IsDisableRuntimeConfig(common::kRuntimeInline);
  context->SetCellReuseLevel(CellReuseLevel::kNoCellReuse);

  MS_LOG(INFO) << "Cell reuse(@lazy_inline) actually takes effect.";
  auto cell_reuse_level =
    (enable_ge && !context->IsKByKExecutorMode()) ? CellReuseLevel::kNoInline : CellReuseLevel::kLazyInline;
  if (force_no_inline) {
    cell_reuse_level = CellReuseLevel::kNoInline;
  }
  context->SetCellReuseLevel(cell_reuse_level);

  return true;
}

// Used for excluding the func graphs used in vmap.
bool UsedByVmap(const FuncGraphPtr &func_graph) {
  const auto &cnodes_index = func_graph->func_graph_cnodes_index();
  if (cnodes_index.empty()) {
    return false;
  }
  const auto matcher = [&func_graph](const std::pair<const CNodeIndexPairPtr, int64_t> &cnode_index) {
    const auto &cnode = cnode_index.first->first;
    const auto &vmap_meta = GetCNodeValueWithoutDoSignature(cnode);
    if (vmap_meta != nullptr && vmap_meta->isa<prim::VmapOperation>()) {
      MS_LOG(DEBUG) << "Found VMap CNode: " << cnode->DebugString();
      return true;
    }
    // The func graph is used in MakeTuple or UnpackGraph.
    const auto user_matcher = [](const FuncGraphPtr &func_graph, const AnfNodePtr &cnode) {
      auto manager = func_graph->manager();
      MS_EXCEPTION_IF_NULL(manager);
      auto &users = manager->node_users()[cnode];
      for (const auto &user : users) {
        const auto &user_vmap_meta = GetCNodeValueWithoutDoSignature(user.first);
        if (user_vmap_meta != nullptr && user_vmap_meta->isa<prim::VmapOperation>()) {
          MS_LOG(DEBUG) << "Found VMap CNode: " << user.first->DebugString();
          return true;
        }
      }
      return false;
    };
    const auto unpack_graph_prim = GetCNodePrimitive(cnode);
    if (unpack_graph_prim != nullptr && unpack_graph_prim->isa<prim::UnpackGraphPrimitive>()) {
      return user_matcher(func_graph, cnode);
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return user_matcher(func_graph, cnode);
    }
    // Deal with F.vmap(fn, ...) in construct().
    // Do not check fn passed from nested func graph calls.
    if (cnode_index.first->second == 1) {
      const auto vmap_func = GetCNodeFuncGraph(cnode);
      if (vmap_func == nullptr) {
        return false;
      }
      auto first_param = vmap_func->parameters()[0];
      return user_matcher(func_graph, first_param);
    }
    return false;
  };
  return std::any_of(cnodes_index.cbegin(), cnodes_index.cend(), matcher);
}
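
// Cases the matcher above recognizes (illustrative): fn's graph is excluded when
// it reaches a VmapOperation directly, e.g. F.vmap(fn)(x), or indirectly through
// an UnpackGraph or MakeTuple user, e.g. F.vmap(fn, in_axes=(0, None))(x, y).
// PreCConvAction below passes UsedByVmap to LiftingClone as the exclusion filter,
// so such graphs keep their free variables instead of having them lifted.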

bool PreCConvAction(const ResourcePtr &resource) {
  static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
  if (!enable_pre_lift) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(resource);
  MS_EXCEPTION_IF_NULL(resource->func_graph());
  FuncGraphPtr func_graph = resource->func_graph();
  FuncGraphPtr new_fg = LiftingClone(func_graph, false, UsedByVmap);
  resource->set_func_graph(new_fg);
  return GradPartialTransformPass(resource);
}

bool SymbolResolveAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->manager() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "SymbolResolve error, manager is null";
  }
  auto func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "SymbolResolve error, graph is null";
  }
  bool ret = parse::ResolveFuncGraph(func_graph, resource);
  // Remove unused nodes from the cnode order list,
  // and check isolated side-effect nodes.
  if (func_graph != nullptr) {
    func_graph->EraseUnusedNodeInOrder();
    for (auto fg : func_graph->func_graphs_used_total()) {
      if (fg != nullptr) {
        fg->EraseUnusedNodeInOrder();
      }
    }
  }
  return ret;
}

bool SetMixedPrecisionAction(const ResourcePtr &resource) {
  if (resource->manager() == nullptr) {
    MS_LOG(EXCEPTION) << "SetMixedPrecisionAction error, manager is null";
  }
  auto func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(EXCEPTION) << "SetMixedPrecisionAction error, graph is null";
  }
  SetCalledSubGraphMixedPrecisionFlag(func_graph);
  MS_LOG(DEBUG) << "Finished setting the mixed precision flag in subgraphs.";
  return true;
}

bool AutoMonadAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->manager() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Auto-Monad failed, manager is null";
  }
  auto func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Auto-Monad failed, graph is null";
  }
  (void)pipeline::AutoMonad(func_graph);
  return true;
}

bool OrderEnforceAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->manager() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Order-Enforce error, manager is null";
  }
  auto func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Order-Enforce error, graph is null";
  }
  pipeline::OrderEnforce(func_graph);
  return true;
}

// Get the abstract of the default value in the given parameter.
AbstractBasePtr GetDefaultValueAbstract(const ParameterPtr &param) {
  auto value = param->default_param();
  MS_EXCEPTION_IF_NULL(value);
  auto value_abs = value->ToAbstract();
  MS_EXCEPTION_IF_NULL(value_abs);
  if (value_abs->isa<abstract::AbstractMapTensor>()) {
    // Return AbstractMapTensor for a map parameter.
    return value_abs;
  }
  // Make an AbstractRefTensor for the tensor value.
  auto abs_tensor = value_abs->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(abs_tensor);
  auto ref_key = std::make_shared<RefKey>(param->name());
  return std::make_shared<abstract::AbstractRefTensor>(abs_tensor, ref_key);
}
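
// Sketch of the result (illustrative, hypothetical parameter): for a Parameter
// named "w" whose default value is a float32 tensor of shape [2, 3], the returned
// abstract is conceptually
//
//   AbstractRefTensor(AbstractTensor(float32, [2, 3]), RefKey("w"))
//
// while a map parameter returns its AbstractMapTensor unchanged via the early exit.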
1075 
1076 namespace {
GetArgsAbs(const ResourcePtr & resource)1077 abstract::AbstractBasePtrList GetArgsAbs(const ResourcePtr &resource) {
1078   FuncGraphPtr func_graph = resource->func_graph();
1079   abstract::AbstractBasePtrList args_abs = resource->args_abs();
1080 
1081   // Parallel checking.
1082   auto context = parallel::ParallelContext::GetInstance();
1083   MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
1084 
1085   // Handle the Parameter from FV inputs.
1086   for (const auto &param : func_graph->parameters()) {
1087     auto param_node = std::static_pointer_cast<Parameter>(param);
1088     MS_EXCEPTION_IF_NULL(param_node);
1089     if (param_node->has_default()) {
1090       auto param_abs = GetDefaultValueAbstract(param_node);
1091       context->ParallelParameterContextRestoreShape(func_graph, param_node, param_abs);
1092       (void)args_abs.emplace_back(param_abs);
1093     }
1094   }
1095   return args_abs;
1096 }
1097 }  // namespace
1098 
TypeInferenceAction(const ResourcePtr & resource)1099 bool TypeInferenceAction(const ResourcePtr &resource) {
1100   EventMessage::PrintCompileStatusMessage("Start performing static analysis and type inference.");
1101   MS_EXCEPTION_IF_NULL(resource);
1102   if (resource->func_graph() == nullptr) {
1103     MS_LOG(INTERNAL_EXCEPTION) << "AbstractSpecialize error";
1104   }
1105   SetMindIRLoadFlag(resource);
1106   // Abstract analyze
1107   auto engine = resource->engine();
1108   MS_EXCEPTION_IF_NULL(engine);
1109 
1110   // Check isolated side-effect nodes.
1111   engine->set_check_side_effect(true);
1112   // Analyze
1113   (void)profiler::CollectHostInfo(kCompiler, kTypeInference, kAbstractAnalyze, 0, 0, 0);
1114   AnalysisResult result;
1115   {
1116     MsProfileStatGuard stat_guard("type_inference.infer");
1117     result = AbstractAnalyze(resource->engine(), resource->func_graph(), GetArgsAbs(resource), resource->is_load());
1118   }
1119   (void)profiler::CollectHostInfo(kCompiler, kTypeInference, kAbstractAnalyze, 0, 0, 1);
1120   // Specialize
1121   (void)profiler::CollectHostInfo(kCompiler, kTypeInference, kProgramSpecialize, 0, 0, 0);
1122   FuncGraphPtr new_fg;
1123   {
1124     MsProfileStatGuard stat_guard("type_inference.specialize");
1125     new_fg = ProgramSpecialize(resource->engine(), result.context->func_graph(), result.context);
1126   }
1127   (void)profiler::CollectHostInfo(kCompiler, kTypeInference, kProgramSpecialize, 0, 0, 1);
1128   // Update the top func graph with the specialized graph.
1129   parse::Parser::UpdateTopFuncGraph(new_fg);
1130   resource->set_func_graph(new_fg);
1131   engine->set_check_side_effect(false);
1132 
1133   // Remove unused nodes in cnode order list, this is prepared for auto-monad.
1134   if (new_fg) {
1135     new_fg->EraseUnusedNodeInOrder();
1136     for (auto fg : new_fg->func_graphs_used_total()) {
1137       if (fg) {
1138         fg->EraseUnusedNodeInOrder();
1139       }
1140     }
1141   }
1142 
1143   UpdateFuncGraphParameter(new_fg, resource->arguments());
1144   MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
1145   return true;
1146 }
1147 
OptimizeAction(const ResourcePtr & resource,const std::vector<PassItem> & passes)1148 bool OptimizeAction(const ResourcePtr &resource, const std::vector<PassItem> &passes) {
1149   MS_EXCEPTION_IF_NULL(resource);
1150   size_t counter = 0;
1151   for (auto &pass : passes) {
1152     ProcessStatus::GetInstance().RecordStart(pass.first);
1153     (void)profiler::CollectHostInfo(kCompiler, kOptimize, pass.first, 0, 0, 0);
1154     auto profile_context = MsProfile::GetProfile()->Step(pass.first);
1155     auto pass_func = [&pass, &resource, &counter]() {
1156       MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
1157       auto result = pass.second(resource);
1158       if (!result) {
1159         MS_LOG(INTERNAL_EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
1160       }
1161 #ifdef ENABLE_DUMP_IR
1162       auto context = MsContext::GetInstance();
1163       MS_EXCEPTION_IF_NULL(context);
1164       if (context->CanDump(kIntroductory) && resource->func_graph() != nullptr) {
1165         auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
1166         auto func_graph = resource->func_graph();
1167         MS_EXCEPTION_IF_NULL(func_graph);
1168         static const auto switch_order = (common::GetEnv("MS_DEV_SAVE_GRAPHS_SORT_MODE") == "1");
1169         if (switch_order) {
1170           ExportIR(fg_name + ".ir", func_graph);
1171         } else {
1172           DumpIR(fg_name + ".ir", func_graph);
1173         }
1174         if (context->CanDump(kFully)) {
1175           draw::Draw(fg_name + ".dot", func_graph);
1176         }
1177         MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
1178       }
1179 #endif
1180       counter++;
1181       MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
1182     };
1183     ProfileExecute(profile_context, pass_func);
1184     (void)profiler::CollectHostInfo(kCompiler, kOptimize, pass.first, 0, 0, 1);
1185     ProcessStatus::GetInstance().RecordEnd();
1186   }
1187 
1188   return true;
1189 }

bool OptInlineAction(const ResourcePtr &resource) {
  if (parallel::ParallelContext::GetInstance()->parallel_mode() == "semi_auto_parallel" ||
      parallel::ParallelContext::GetInstance()->parallel_mode() == "auto_parallel") {
    return OptimizeAction(resource, kInlinePasses);
  }
  return true;
}

bool VmOptimizeAction(const ResourcePtr &resource) {
  EventMessage::PrintCompileStatusMessage("Start performing graph optimization.");
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_ps_mode()) {
    (void)kVmPasses.emplace_back(PassItem("server_communication_op_fusion", [](const ResourcePtr &res) -> bool {
      MS_EXCEPTION_IF_NULL(res);
      return ps::Util::FuseServerCommOps(res->func_graph());
    }));
  }
#endif
  auto ret = OptimizeAction(resource, kVmPasses);
  TraceManager::CloseParserDebugInfoFlag();
  return ret;
}

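// Control sink is enabled only when all of the following hold: graph mode, Ascend device
// target, task sink enabled, and multi-graph sink enabled.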
static bool IsCtrlSink() {
  auto ms_ctx = MsContext::GetInstance();
  if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {
    return false;
  }

  std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target != kAscendDevice) {
    return false;
  }

  if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
    return false;
  }

  return ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK);
}

bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
  if (func_graph != nullptr) {
    AnfNodePtr output = func_graph->output();
    if (output != nullptr && (output->isa<ValueNode>() || output->isa<Parameter>())) {
      return true;
    }
  }
  return false;
}

bool GetJitBpropGraph(const ResourcePtr &resource) {
  // This function only works in Pynative mode. The func_graph is decorated with 'jit'.
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    return true;
  }
  return pynative::PyNativeExecutor::GetInstance()->grad_executor()->jit()->GetJitGradGraph(resource);
}

bool RewriterAfterOptAPassAfterJitBprop(const ResourcePtr &resource) {
  // This function is only used to convert unsupported syntax into PyExecute nodes through Fallback,
  // when the forward graph is decorated with 'jit' and differentiated in pynative mode.
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->not_convert_jit()) {
    context->set_not_convert_jit(false);
    MS_EXCEPTION_IF_NULL(resource);
    FuncGraphPtr func_graph = resource->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    (void)mindspore::opt::RewriterAfterOptA(func_graph, resource);
    UpdateArgsSpec(func_graph, resource);
  }
  return true;
}

bool EliminateSpecialOpNode(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->manager() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "PynativeElimOpt error, manager is null.";
  }
  if (resource->func_graph() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "PynativeElimOpt error, graph is null.";
  }
  return EliminateSpecialOpOptPass(resource);
}

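// Check whether the graph contains an "incorporate call": a call site whose callee is not a
// statically-known FuncGraph value. For example (IR sketch):
//   %1 = Partial(%closure, x)   // %closure is not a ValueNode<FuncGraph>
//   %2 = %1(y)                  // indirect call
// Switch and SwitchLayer branches are checked in the same way.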
bool HasIncorporateCall(const std::vector<AnfNodePtr> &all_nodes) {
  for (const auto &node : all_nodes) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
      auto partial_function = cnode->input(kPartialGraphIndex);
      if (!IsValueNode<FuncGraph>(partial_function)) {
        MS_LOG(INFO) << "Partial has indirect call: " << cnode->DebugString();
        return true;
      }
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
      const auto &switch_inputs = cnode->inputs();
      if (std::any_of(switch_inputs.begin() + kSwitchTrueBranchIndex, switch_inputs.end(), [](const AnfNodePtr &input) {
            return !IsPrimitiveCNode(input, prim::kPrimPartial) && !IsValueNode<FuncGraph>(input);
          })) {
        MS_LOG(INFO) << "Switch has indirect call: " << cnode->DebugString();
        return true;
      }
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
      auto make_tuple = cnode->input(kSwitchLayerBranchesIndex);
      if (!IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple)) {
        MS_LOG(INTERNAL_EXCEPTION) << "SwitchLayer input2 should be make_tuple, but got: " << make_tuple->DebugString();
      }
      const auto &make_tuple_inputs = make_tuple->cast<CNodePtr>()->inputs();
      if (std::any_of(make_tuple_inputs.begin() + 1, make_tuple_inputs.end(), [](const AnfNodePtr &input) {
            return !IsPrimitiveCNode(input, prim::kPrimPartial) && !IsValueNode<FuncGraph>(input);
          })) {
        MS_LOG(INFO) << "SwitchLayer has indirect call: " << cnode->DebugString();
        return true;
      }
      continue;
    }
    if (common::AnfAlgo::HasIncorporateCallNode(cnode)) {
      return true;
    }
  }
  return false;
}

bool ExistTarget(const std::vector<AnfNodePtr> &all_nodes, const std::string &target) {
  for (const auto &node : all_nodes) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    if (GetCNodeTarget(node) == target) {
      return true;
    }
  }
  return false;
}

// If the return value of a subgraph is a Ref in control flow scenarios, graph mode should run kernel by kernel.
bool ExistSwitchRef(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &all_nodes) {
  // %1 = switch(cond, func1, func2)
  // %2 = %1()  if the abstract of the node is AbstractRefTensor or Tuple/List(AbstractRefTensor, ...), return true.
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
      continue;
    }
    auto iter = node_users.find(node);
    if (iter != node_users.end()) {
      auto &users = iter->second;
      for (auto &user : users) {
        auto &user_node = user.first;
        if (common::AnfAlgo::HasAbstractRef(user_node) || common::AnfAlgo::SequenceHasAbstractRef(user_node)) {
          if (device_target == kAscendDevice) {
            MS_LOG(WARNING) << "On the Ascend platform, if you only need read access to a parameter, "
                            << "you can take the value of the parameter so that the system can do more optimization. "
                            << "For example, change 'return param' to 'return param.value()'\n"
                            << "Please check your code:" << trace::GetDebugInfoStr(user_node->debug_info());
          }
          return true;
        }
      }
    }
  }
  return false;
}

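// Decide the sink mode for graphs with control flow. Returns false when the run mode has
// already been set here (control flow with incorporate calls, while loops or switch refs,
// or multiple device targets); returns true when the caller should continue deciding.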
bool SetModeForControlFlow(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &all_nodes, bool pynative_mode,
                           compile::Backend *backend_ptr) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(backend_ptr);
  auto set_ctx = [&context_ptr, &backend_ptr](bool task_sink, bool is_multi_graph_sink, bool enable_loop_sink) {
    context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, task_sink);
    context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, is_multi_graph_sink);
    context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, enable_loop_sink);
    backend_ptr->set_is_multi_graph_sink(is_multi_graph_sink);
  };
  // GRAPH | Closure\ENV\While scenario : KernelByKernel path in MindRT.
  auto graphs = func_graph->func_graphs_used_total();
  (void)graphs.insert(func_graph);
  bool exist_control_flow = ExistControlFlow(func_graph);
  bool exist_func = exist_control_flow && HasIncorporateCall(all_nodes);
  if (exist_func) {
    if (!pynative_mode) {
      MS_LOG(INFO) << "Run graph mode with sub graph sink because the graph has control flow and incorporate calls.";
      set_ctx(true, false, false);
    } else {
      MS_LOG(INFO) << "Run graph mode with kernel by kernel because the graph has control flow and incorporate calls.";
      set_ctx(false, false, false);
    }
    return false;
  }
  bool exist_while =
    std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
  MS_LOG(INFO) << func_graph->ToString() << " exist_while: " << exist_while;
  if (exist_while || ExistSwitchRef(func_graph, all_nodes)) {
    if (!pynative_mode) {
      MS_LOG(INFO) << "Run graph mode with sub graph sink because the graph has a while loop or switch ref.";
      set_ctx(true, false, false);
    } else {
      MS_LOG(INFO) << "Run graph mode with kernel by kernel because the graph has a while loop or switch ref.";
      set_ctx(false, false, false);
    }
    return false;
  }
  // Multiple device targets scenario.
  if (func_graph->exist_multi_target()) {
    // Heterogeneous scenario + ControlFlow : KernelByKernel path in MindRT.
    if (exist_control_flow && pynative_mode) {
      MS_LOG(INFO) << "Run graph mode with kernel by kernel because the graph has multiple device targets and "
                      "control flow.";
      set_ctx(false, false, false);
      return false;
    }
    // GRAPH | Heterogeneous scenario : No control flow, subgraph sink path in MindRT.
    MS_LOG(INFO) << "Run graph mode with subgraph sink because the graph has multiple device targets.";
    set_ctx(true, false, false);
    return false;
  }
  return true;
}

bool IsCellReuse(const AnfNodePtr &input) {
  if (IsValueNode<FuncGraph>(input)) {
    auto fg = GetValueNode<FuncGraphPtr>(input);
    MS_EXCEPTION_IF_NULL(fg);
    if (fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
      return true;
    }
  }
  return false;
}

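// Downgrade the cell reuse level from lazy inline to no-inline when lazy inline cannot be
// applied safely: the graph contains a while loop, a cell-reuse graph contains switch nodes
// or nested cell reuse, or (with pipeline parallel) the number of cell-reuse call sites
// exceeds kLazyInlineThershold.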
void ProcessCanNotInline(const FuncGraphPtr &func_graph, const std::shared_ptr<MsContext> &context_ptr) {
  auto graphs = func_graph->func_graphs_used_total();
  (void)graphs.insert(func_graph);
  bool exist_while =
    std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
  if (exist_while && context_ptr->CellReuseLevel() == CellReuseLevel::kLazyInline) {
    MS_LOG(INFO) << "Set no inline because graph has while.";
    context_ptr->SetCellReuseLevel(CellReuseLevel::kNoInline);
  }

  auto cant_inline_cell_reuse = [](const FuncGraphPtr &fg) -> bool {
    if (!fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
      return false;
    }
    MS_LOG(INFO) << "Cell reuse graph: " << fg->ToString();
    // cell reuse func graph has switch
    if (!fg->switch_nodes().empty()) {
      MS_LOG(INFO) << "Set no inline because cell reuse graph has switch, " << fg->ToString();
      return true;
    }
    // cell reuse sub graph has switch or cell reuse
    for (auto &sub_graph : fg->func_graphs_used_total()) {
      if (!sub_graph->switch_nodes().empty() || sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
        MS_LOG(INFO) << "Set no inline because cell reuse sub graph has switch or nested cell reuse, "
                     << sub_graph->ToString();
        return true;
      }
    }
    return false;
  };
  if (std::any_of(graphs.cbegin(), graphs.cend(), cant_inline_cell_reuse)) {
    MS_LOG(INFO) << "Set no inline because cell reuse graph has switch or nested cell reuse.";
    context_ptr->SetCellReuseLevel(CellReuseLevel::kNoInline);
  }
  if (!common::IsEnableRuntimeConfig(common::kRuntimeInline)) {
    const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
    size_t micro_num = 0;
    for (auto &node : all_nodes) {
      if (!node->isa<CNode>()) {
        continue;
      }
      auto cnode = node->cast<CNodePtr>();
      if (IsCellReuse(cnode->input(0))) {
        micro_num++;
      }
    }
    auto parallel_context = parallel::ParallelContext::GetInstance();
    MS_EXCEPTION_IF_NULL(parallel_context);
    auto stages = parallel_context->pipeline_stage_split_num();
    if (stages <= 1) {
      return;
    }
    MS_LOG(INFO) << "Cell reuse micro num: " << micro_num;
    if (micro_num > kLazyInlineThershold) {
      MS_LOG(INFO) << "Set no inline because cell reuse micro num is greater than " << kLazyInlineThershold
                   << ", micro num: " << micro_num;
      context_ptr->SetCellReuseLevel(CellReuseLevel::kNoInline);
    }
  }
}

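// Select the execution mode for an Ascend graph, checked in order: explicit kernel-by-kernel
// executor mode, dynamic shape, dynamic scalar (backoff) ops, control flow scenarios,
// PS cache, and finally the default multi-graph sink.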
void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr, std::string *kbk_reason) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(backend_ptr);
  auto set_ctx = [&context_ptr, &backend_ptr](bool task_sink, bool is_multi_graph_sink, bool enable_loop_sink) {
    context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, task_sink);
    context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, is_multi_graph_sink);
    context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, enable_loop_sink);
    backend_ptr->set_is_multi_graph_sink(is_multi_graph_sink);
  };
  ProcessCanNotInline(func_graph, context_ptr);
  auto jit_level = pipeline::GetJitLevel();
  func_graph->set_attr(kAttrJitLevel, MakeValue<std::string>(jit_level));
  auto jit_config = PhaseManager::GetInstance().jit_config();
  jit_config[kAttrJitLevel] = context_ptr->GetJitLevel();
  graphkernel::GraphKernelFlags::SaveJitConfig(jit_config);

  const bool pynative_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
  const auto &device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (pynative_mode && device_target != kAscendDevice) {
    return;
  }
  const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
  // GPU/CPU: no need to set any context.
  if (!ExistTarget(all_nodes, kAscendDevice)) {
    return;
  }

  // GRAPH | Single Op : KernelByKernel path in MindRT.
  if (context_ptr->IsKByKExecutorMode()) {
    if (kbk_reason != nullptr) {
      *kbk_reason = "Run graph mode with kernel by kernel by configuration.";
      MS_LOG(INFO) << *kbk_reason;
    }
    set_ctx(false, false, false);
    return;
  }

  // GRAPH | Dynamic Shape : KernelByKernel path in MindRT.
  if (common::AnfAlgo::IsDynamicGraph(func_graph) && (context_ptr->backend_policy() != "ge")) {
    if (kbk_reason != nullptr) {
      *kbk_reason =
        "Run graph mode with kernel by kernel because the graph has dynamic shape. Call "
        "'set_context(save_graphs=True)' to check the graph IRs.";
      MS_LOG(INFO) << *kbk_reason;
    }
    set_ctx(false, false, false);
    return;
  }

  // GRAPH | Dynamic Scalar : Dynamic scalar ops in graph.
  if (IsNeedBackoffGraph(func_graph) && !common::AnfAlgo::IsDynamicGraph(func_graph)) {
    if (kbk_reason != nullptr) {
      *kbk_reason = "Run graph mode with kernel by kernel because the graph has dynamic scalar ops.";
      MS_LOG(INFO) << *kbk_reason;
    }
    set_ctx(false, false, false);
    return;
  }
  if (!SetModeForControlFlow(func_graph, all_nodes, pynative_mode, backend_ptr)) {
    return;
  }

#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->cache_enable()) {
    MS_LOG(INFO) << "Run graph mode with subgraph sink because PS cache is enabled.";
    set_ctx(true, false, false);
    return;
  }
#endif

  // GRAPH | normal network and if/for/switch scenario etc : MultiGraph path in MindRT.
  MS_LOG(INFO) << "Run graph mode with multi graph sink.";
  set_ctx(true, true, !pynative_mode);
}

void OriginSetRunMode(const ResourcePtr &resource) {
  FuncGraphPtr func_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto bc_ptr = resource->GetBackend();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string backend = context_ptr->backend_policy();
  auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  if (func_graph->exist_multi_target() || !task_sink) {
    bc_ptr->set_is_multi_graph_sink(false);
    context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
    context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
  } else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
    std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    auto manager = func_graph->manager();
    auto graphs = manager->func_graphs();
    if (graphs.size() > 1 && device_target == kAscendDevice) {
      MS_LOG(INFO) << "This func_graph has control flow nodes, owns " << graphs.size() << " subgraphs.";
    }
    bool exist_while =
      std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
    if (device_target == kAscendDevice && backend != kMsVm && !exist_while) {
      MS_LOG(INFO) << "Run graph mode with multigraph sink.";
      bc_ptr->set_is_multi_graph_sink(true);
      context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
    } else {
      MS_LOG(INFO) << "Run graph mode with vm.";
      bc_ptr->set_is_multi_graph_sink(false);
      context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
      context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
    }
  }
}

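// Entry point for run-mode selection: dispatches to the MindRT path or the legacy path,
// then rejects ranktable-launched HCCL jobs that would end up in kernel-by-kernel mode.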
void SetRunMode(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // The root cause of KernelByKernel mode should be returned.
  std::string kbk_reason = "";
  if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    SetRunMode(resource->func_graph(), resource->GetBackend().get(), &kbk_reason);
  } else {
    OriginSetRunMode(resource);
  }
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  auto is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  auto enable_hccl = context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL);
  if ((!is_task_sink ||
       (context_ptr->IsKByKExecutorMode() && common::AnfAlgo::IsDynamicGraph(resource->func_graph()))) &&
      mode == kGraphMode && enable_hccl && !common::UseHostCollective() && common::GetEnv(kSimulationLevel).empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Current execution mode is 'kernelbykernel', reason: " << kbk_reason
                               << ", but you're launching the job using 'ranktable', which "
                                  "does not support 'kernelbykernel' mode.\n Please refer to link: "
                                  "https://www.mindspore.cn/tutorials/experts/en/master/parallel/startup_method.html "
                                  "and use 'Dynamic cluster' (suggested) or 'mpirun' to launch your job.";
  }
}

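// Compile the optimized graph into an executable form: MindRT actors, a control-sink
// graph, or the legacy VM, depending on the backend policy and context flags.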
bool TaskEmitAction(const ResourcePtr &resource) {
  EventMessage::PrintCompileStatusMessage("Start generating kernels.");
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "TaskEmit args error";
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (mode == kGraphMode && CheckGraphOutputConstOrParameter(func_graph)) {
    return true;
  }

  // In PyNative mode, multi target will be generated with -1 shape in jit. But jit with -1 shape runs as a call
  // graph; control flow does not have the flag kFlagJitCallGraph.
  bool is_control_flow = !func_graph->func_graphs_used_total().empty();
  if (mode == kGraphMode || (mode == kPynativeMode && (func_graph->has_flag(kFlagJitCallGraph) || is_control_flow))) {
    func_graph->SetMultiTarget();
    if (func_graph->exist_multi_target() && DumpJsonParser::GetInstance().IsDumpEnabled()) {
      MS_LOG(WARNING) << "Multi device target is detected, CPU data is dumped in the rank_0 directory";
    }
  }
  DisableMindRT(resource);

  SetRunMode(resource);
  auto bc_ptr = resource->GetBackend();
  MS_EXCEPTION_IF_NULL(bc_ptr);
  const auto &backend = context_ptr->backend_policy();
  // The graph compiling of MindRT.
  if ((backend == kMsConvert || backend == kGeVm) && context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    TaskEmitActionForMindRT(resource);
    return true;
  }
  // The graph compiling of control sink.
  if (IsCtrlSink() && (backend == kMsConvert || backend == kGeVm)) {
    auto graph_id = bc_ptr->CompileGraph(NOT_NULL(func_graph));
    resource->SetResult(kOutput, graph_id);
    return true;
  }
  std::vector<PrimitivePtr> cut_list = compile::GetNonlinearOps();
  if (bc_ptr->name() == kMsConvert || bc_ptr->name() == kGeVm) {
    cut_list = compile::GetMsNonlinearOps();
  }
  std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
  auto vm = compile->CompileAndLink(func_graph);
  resource->SetResult(kOutput, vm);
  return true;
}

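// Wrap the compiled result stored under kOutput into a callable VmEvalFunc, matching the
// path (MindRT, control sink, or VM) chosen in TaskEmitAction.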
bool ExecuteAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
      CheckGraphOutputConstOrParameter(resource->func_graph())) {
    return true;
  }
  if (!resource->HasResult(kOutput)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Execute args error";
  }
  std::string backend = MsContext::GetInstance()->backend_policy();
  // The graph running of mindRT.
  if ((backend == kMsConvert || backend == kGeVm) && MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    ExecuteActionForMindRT(resource);
    return true;
  }

  // The graph running of control sink.
  if (IsCtrlSink() && (backend == kMsConvert || backend == kGeVm)) {
    auto graph_id = resource->GetResult(kOutput).cast<GraphId>();
    auto bc_ptr = resource->GetBackend();
    compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
    MS_EXCEPTION_IF_NULL(msbc_ptr);
    compile::VmEvalFuncPtr run =
      std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
        MS_LOG(INFO) << "Execute args size " << args.size();
        auto outs = msbc_ptr->RunGraph(graph_id, args);
        MS_LOG(DEBUG) << "out size " << outs.size();
        return outs[0];
      });
    resource->SetResult(kOutput, run);
    return true;
  }

  compile::FinalVMPtr vm = resource->GetResult(kOutput).cast<compile::FinalVMPtr>();
  if (vm == nullptr) {
    MS_LOG(INFO) << "Call GE to run the func_graph instead of VM";
    return true;
  }
  compile::VmEvalFuncPtr run =
    std::make_shared<compile::VmEvalFunc>(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1));
  resource->SetResult(kOutput, run);
  return true;
}

#if defined(__linux__) && defined(WITH_BACKEND)
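// Split the graph for distributed execution according to this process's rank id and node
// role, then renormalize (re-infer shapes and abstracts) if the splitter requested it.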
bool DistributedSplitAction(const ResourcePtr &resource) {
  // Only run this action when the cluster is initialized.
  if (!distributed::cluster::ClusterContext::instance()->initialized()) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
  auto node = distributed::cluster::ClusterContext::instance()->node();
  MS_EXCEPTION_IF_NULL(node);
  auto node_role = distributed::cluster::ClusterContext::instance()->node_role();

  parallel::GraphSplitterPtr splitter =
    std::make_shared<parallel::GraphSplitter>(func_graph, node->rank_id(), node_role);
  MS_EXCEPTION_IF_NULL(splitter);
  splitter->Run();
  // Renormalize: infer shape and set abstract for all nodes in the graph.
  if (func_graph->has_flag(kFlagNeedRenormalize)) {
    abstract::AbstractBasePtrList args_abs;
    auto parameters = func_graph->parameters();
    (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
                         [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
    FuncGraphPtr new_fg = Renormalize(resource, func_graph, args_abs);
    resource->set_func_graph(new_fg);
    resource->set_args_abs(args_abs);
  }
  return true;
}
#endif

// A value node used by a parallel primitive might be partitioned so that its value differs by device,
// which would result in a synchronization error due to different execution orders.
// Here we temporarily avoid the problem by skipping value node merging for parallel related primitives;
// the final solution will be proposed later as a parallel feature.
bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  MS_EXCEPTION_IF_NULL(resource->manager());
  auto &node_users = resource->manager()->node_users();
  auto &users = node_users[value_node];
  auto used_by_keep_value_prim =
    std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
      MS_EXCEPTION_IF_NULL(user.first);
      auto cnode = user.first->cast<CNodePtr>();
      if (cnode == nullptr) {
        return false;
      }
      auto prim_node = cnode->input(0);
      if (IsValueNode<Primitive>(prim_node)) {
        auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
        MS_EXCEPTION_IF_NULL(prim);
        // value_node is referenced by some parallel primitive
        return prim->HasAttr("keep_value_node_input");
      }
      return false;
    });
  return used_by_keep_value_prim;
}

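// Deduplicate value nodes in the graph by hash, skipping those that must stay duplicated
// for parallel primitives (see KeepValueNodeDuplication above).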
bool RemoveValueNodeDuplicationsAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Remove value node duplications error.";
  }
  auto manager = resource->manager();
  // Remove duplicated value nodes; due to the replace operation, a reference can't be used.
  auto value_nodes = func_graph->value_nodes();
  HashCache hash_cache;
  HashValue hashes;
  for (const auto &value_pair : value_nodes) {
    if (KeepValueNodeDuplication(value_pair.first, resource)) {
      continue;
    }
    TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
  }
  return true;
}

bool PipelineSplitAction(const ResourcePtr &resource) { return PipelineSplitPass(resource); }

bool ParallelVirtualDatasetAction(const ResourcePtr &resource) { return ParallelVirtualDatasetPass(resource); }

bool AutoParallelSymbolWithReNormalizeAction(const ResourcePtr &resource) {
  return AutoParallelSymbolPassWithReNormalize(resource);
}

bool PipelineSchedulerAction(const ResourcePtr &resource) { return PipelineParallelScheduler(resource); }

bool AutoParallelAction(const ResourcePtr &resource) { return AutoParallelPass(resource); }

bool ValidateAction(const ResourcePtr &resource) {
  auto res = ValidatePass(resource);
#ifdef DEBUG
  FuncGraphLoopBreaker::Inst().Dump();
#endif
  return res;
}

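// Attach a FuncGraph previously exported to MindIR (stored on the cell as the attribute
// "graph_load_from_mindir") to the resource, and re-infer it when the current input
// arguments are not compatible with those used at export time.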
bool SetMindIRGraphAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  resource->set_is_load(true);
  auto cell = py::cast<CellPtr>(resource->source_input());
  if (cell == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "The graph loaded from mindir is null.";
  }
  const std::string mindir_graph = "graph_load_from_mindir";
  auto obj = cell->GetAttr(mindir_graph);
  if (obj == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "The graph loaded from mindir is null. The cell has no attribute: " << mindir_graph;
  }
  auto fg = GetValue<FuncGraphPtr>(obj);
  if (fg == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "The graph loaded from mindir is null.";
  }
  resource->set_func_graph(fg);
  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    auto res_mng = resource->manager();
    MS_EXCEPTION_IF_NULL(res_mng);
    res_mng->Clear();
    res_mng->AddFuncGraph(fg);
  }
  abstract::AbstractBasePtrList broaded_args;
  const auto &args_abs_list = resource->args_abs();
  (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(broaded_args),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         if (arg->GetValueTrack() != kValueAny) {
                           return arg->Broaden();
                         }
                         return arg;
                       });

  abstract::AbstractBasePtrList func_args;
  const auto inputs = fg->get_inputs();
  (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(func_args),
                       [](const AnfNodePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         auto abs = arg->abstract();
                         MS_EXCEPTION_IF_NULL(abs);
                         return abs->Broaden();
                       });

  bool is_equal_input_args = true;
  if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
    MS_LOG(INFO) << "The input arguments are not compatible with the function graph that was exported before. "
                 << "Please check that the args are the same as at export.\n"
                 << "The export input argument size: " << func_args.size() << "\n"
                 << "The load input argument size: " << broaded_args.size() << "\n"
                 << "Export input args info: " << abstract::ArgsToString(func_args) << "\n"
                 << "The input args info: " << abstract::ArgsToString(broaded_args);
    is_equal_input_args = false;
  }

  if (!is_equal_input_args) {
    // Use InferMindir, which will find the C++ infer in eval_map and backend_eval_map.
    (void)InferMindir(resource->func_graph(), args_abs_list, true);
  }
  return true;
}

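// Build the front-end action list shared by all pipelines: parse/resolve (or bootstrap
// when boosted inference is enabled), type inference, auto-monad, inline, and the
// parallel-related preprocessing actions.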
static std::vector<ActionItem> CommonPipeline(bool trace_flag) {
  std::vector<ActionItem> actions;
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  const bool boost_infer = common::GetEnv("MS_DEV_BOOST_INFER") != "0" && graph_executor->graph_cell_count() == 0;
  if (!trace_flag) {
    if (boost_infer) {
      // Bootstrap for JIT.
      (void)actions.emplace_back(std::make_pair(kBootstrap, BootstrapAction));
    } else {
      // Parse the python ast to ANF graph
      (void)actions.emplace_back(std::make_pair(kParse, ParseAction));

      // Resolve the python func
      (void)actions.emplace_back(std::make_pair(kSymbolResolve, SymbolResolveAction));

      // Notice: Temporary solution, to be implemented using Python Rewriter in the future.
      // Set mixed Precision flag in subgraph.
      static bool enable_set_mixed_precision_flag = (common::GetCompileConfig("AMP_ENABLE_ALL_FG") == "1");
      if (enable_set_mixed_precision_flag) {
        (void)actions.emplace_back(std::make_pair(kSetMixedPrecisionFlag, SetMixedPrecisionAction));
      }

      auto parallel_context = parallel::ParallelContext::GetInstance();
      MS_EXCEPTION_IF_NULL(parallel_context);
      auto parallel_mode = parallel_context->parallel_mode();
      const bool is_parallel_mode =
        parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
      static const auto combine_like_graphs = (common::GetCompileConfig("COMBINE_LIKE_GRAPHS") == "1");
      static const auto force_disable_combine = (common::GetCompileConfig("COMBINE_LIKE_GRAPHS") == "0");
      if (!is_cluster_initialized && (!is_parallel_mode || combine_like_graphs) && !force_disable_combine) {
        (void)actions.emplace_back(std::make_pair(kCombineLikeGraphs, CombineLikeGraphs));
      }

      // Make the reusable cell to be the reusable function graph
      (void)actions.emplace_back(std::make_pair(kGraphReusing, GraphReusingAction));

      // Pre-Lift the func graphs.
      (void)actions.emplace_back(std::make_pair(kPreCConv, PreCConvAction));
    }
  }
  // Evaluate type and shape, and specialize.
  (void)actions.emplace_back(std::make_pair(kTypeInference, TypeInferenceAction));

  // Auto-monad for side-effects handling.
  (void)actions.emplace_back(std::make_pair(kAutoMonad, AutoMonadAction));

  if (boost_infer) {
    (void)actions.emplace_back(std::make_pair(kGraphReusing, GraphReusingAction));
  }

  // Do data structure simplifications and inline.
  (void)actions.emplace_back(std::make_pair(kInline, OptInlineAction));

  (void)actions.emplace_back(std::make_pair("parallel-infer-symbol", AutoParallelSymbolWithReNormalizeAction));
  // Do prepositive auto parallel.
  (void)actions.emplace_back(std::make_pair(kPreAutoParallel, AutoParallelAction));
  // Insert virtual dataset.
  (void)actions.emplace_back(std::make_pair("insert-virtual-dataset", ParallelVirtualDatasetAction));
  (void)actions.emplace_back(std::make_pair("parallel-infer-symbol-second", AutoParallelSymbolWithReNormalizeAction));
  // Do PipelineSplit action.
  (void)actions.emplace_back(std::make_pair(kPipelineSplit, PipelineSplitAction));

  return actions;
}

std::vector<ActionItem> EraseParseActions(const std::vector<ActionItem> &actions) {
  std::vector<ActionItem> filtered_actions;
  for (const auto &item : actions) {
    if (item.first != "parse") {
      (void)filtered_actions.emplace_back(item);
    }
  }
  return filtered_actions;
}

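// Assemble the full VM pipeline: the common front-end actions followed by optimization,
// scheduling, validation, and (unless precompiling or exporting) backend task emission
// and execution. When the compile cache is hit, only the backend actions are kept.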
std::vector<ActionItem> VmPipeline(const ResourcePtr &resource, bool trace_flag, bool erase_parse) {
  is_cluster_initialized = distributed::cluster::ClusterContext::instance()->initialized();
  std::vector<ActionItem> actions;
  // If the compilation cache is enabled and read successfully, only do the backend actions.
  const std::string &phase = PhaseManager::GetInstance().phase();
  if (IsPhaseLoadFromMindIR(phase)) {
    actions = MindIRPipeline();
  } else if (!resource->EnableCompileCache() || resource->func_graph() == nullptr) {
    actions = CommonPipeline(trace_flag);

    // Optimize
    (void)actions.emplace_back(std::make_pair(kOptimize, VmOptimizeAction));

    (void)actions.emplace_back(std::make_pair(kPipelineParallelScheduler, PipelineSchedulerAction));

    (void)actions.emplace_back(std::make_pair(kAutoMonadReorder, OrderEnforceAction));

    // Eliminate forward cnode for grad graph
    (void)actions.emplace_back(std::make_pair(kGetJitBpropGraph, GetJitBpropGraph));

    // Rewriter(dict convert pyexecute) after jit bprop.
    (void)actions.emplace_back(std::make_pair(kRewriterAfterJitBprop, RewriterAfterOptAPassAfterJitBprop));

    // Eliminate the virtual mirror node
    (void)actions.emplace_back(std::make_pair(kEliminateSpecialOpNode, EliminateSpecialOpNode));

#if defined(__linux__) && defined(WITH_BACKEND)
    if (!pipeline::IsPhaseExport(phase)) {
      (void)actions.emplace_back(std::make_pair(kDistributedSplit, DistributedSplitAction));
    }
    if (ps::PSContext::instance()->is_worker()) {
      if (distributed::cluster::ClusterContext::instance()->initialized()) {
        MS_LOG(INFO) << "This worker is initialized. No need to add worker action.";
      } else {
        std::string server_mode = ps::PSContext::instance()->server_mode();
      }
    }
#endif

    // Mind Compiler finish.
    (void)actions.emplace_back(std::make_pair(kValidate, ValidateAction));
  }

  if (erase_parse) {
    actions = EraseParseActions(actions);
  }

  auto is_precompile_only = MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY);
  if (is_precompile_only) {
    MS_LOG(INFO) << "PrecompileOnly, stop running the graph";
    return actions;
  }

  if (common::GetEnv(kSimulationLevel) == kSimulationLevelCompileGraph) {
    return actions;
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
#ifndef WITH_BACKEND
  if (ms_context->backend_policy() != "ge") {
#endif
    // Phases with the "export" prefix need to skip backend compilation.
    if (pipeline::IsPhaseExport(phase)) {
      return actions;
    }
    // Compile the ANF graph
    (void)actions.emplace_back(std::make_pair(kTaskEmit, TaskEmitAction));

    // Execute the graph
    (void)actions.emplace_back(std::make_pair(kExecute, ExecuteAction));
#ifndef WITH_BACKEND
  }
#endif
  return actions;
}

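// Pipeline for graphs loaded from MindIR: attach the graph, adjust it, re-infer, and
// validate. Only supported in GraphMode.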
std::vector<ActionItem> MindIRPipeline() {
  auto context_ptr = MsContext::GetInstance();
  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    MS_LOG(EXCEPTION)
      << "The graph generated from MindIR does not support execution in PynativeMode, please convert "
         "it to GraphMode.";
  }
  std::vector<ActionItem> actions;
  // Set the funcGraph loaded from MindIR to the resource.
  (void)actions.emplace_back(std::make_pair(kLoadMindir, SetMindIRGraphAction));
  (void)actions.emplace_back(std::make_pair(kModifyMindirGraph, ModifyGraphGeneratedByMindIR));
  (void)actions.emplace_back(std::make_pair(kInferMindir, InferMindIR));
  (void)actions.emplace_back(std::make_pair(kValidate, ValidateAction));
  return actions;
}
}  // namespace pipeline
}  // namespace mindspore