/**
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/ps/action.h"

#include <memory>
#include <map>
#include <utility>
#include <vector>
#include <set>
#include <string>
#include <algorithm>
#include <functional>

#include "mindspore/core/ops/sequence_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "ir/anf.h"
#include "ir/func_graph_cloner.h"
#include "ir/param_info.h"
#include "ir/cell.h"
#include "include/common/utils/python_adapter.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/parallel_context.h"
#include "abstract/abstract_value.h"
#include "frontend/operator/composite/composite.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/graph_util/graph_splitter.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/shard/shard.h"
#include "pipeline/jit/ps/pipeline.h"
#include "pipeline/jit/ps/pass.h"
#include "pipeline/jit/ps/parse/parse_base.h"
#include "pipeline/jit/ps/parse/data_converter.h"
#include "pipeline/jit/ps/static_analysis/auto_monad.h"
#include "pipeline/jit/ps/static_analysis/order_enforce.h"
#include "pipeline/jit/ps/static_analysis/static_analysis.h"
#include "pipeline/jit/ps/static_analysis/async_eval_result.h"
#include "pipeline/jit/ps/static_analysis/program_specialize.h"
#include "pipeline/jit/ps/resource.h"
#include "pipeline/jit/ps/remove_value_node_dup.h"
#include "pipeline/jit/ps/event_message_print.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/ad/grad.h"
#include "utils/ms_context.h"
#include "utils/ms_utils.h"
#include "utils/phase.h"
#include "utils/compile_config.h"
#include "backend/graph_compiler/transform.h"
#include "load_mindir/infer_mindir.h"
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "backend/common/graph_kernel/graph_kernel_flags.h"
#include "include/backend/debug/profiler/profiling.h"
#include "frontend/optimizer/fallback_rewriter.h"
#include "pipeline/jit/ps/load_mindir.h"
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/cluster/cluster_context.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "include/backend/distributed/ps/util.h"
#endif

namespace mindspore {
namespace pipeline {
namespace {
const auto kFirstInput = 1;
const auto kSecondInput = 2;
const auto kLazyInlineThershold = 64;

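// A graph that uses any other graphs is treated as containing control flow (e.g. subgraphs created for if/while).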
bool ExistControlFlow(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  return !func_graph->func_graphs_used_total().empty();
}

bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
         abs->BuildType()->isa<Number>();
}

bool EnableSequenceBroaden(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  return abs->isa<abstract::AbstractSequence>() &&
         abs->cast<abstract::AbstractSequencePtr>()->ContainsAllBroadenTensors();
}

bool ContainsAbstractFunction(const abstract::AbstractBasePtr &abs) {
  MS_EXCEPTION_IF_NULL(abs);
  if (abs->isa<abstract::AbstractFunction>()) {
    return true;
  }
  if (abs->isa<abstract::AbstractSequence>()) {
    const auto &abs_list = abs->cast<abstract::AbstractSequencePtr>()->elements();
    return std::any_of(abs_list.cbegin(), abs_list.cend(),
                       [](const auto &elem) { return ContainsAbstractFunction(elem); });
  }
  if (abs->isa<abstract::AbstractDictionary>()) {
    const auto &abs_pair_list = abs->cast<abstract::AbstractDictionaryPtr>()->elements();
    return std::any_of(abs_pair_list.cbegin(), abs_pair_list.cend(),
                       [](const auto &pair) { return ContainsAbstractFunction(pair.second); });
  }
  return false;
}

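// Rebuild the top graph's parameter list after specialization: keep parameters that have default values
// or are still bound to variable (non-constant) arguments; parameters whose arguments were constant have
// been folded into the graph and are dropped here.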
void UpdateFuncGraphParameter(const FuncGraphPtr &func_graph, const std::vector<ValuePtr> &arguments) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> new_paras;
  for (size_t i = 0; i < func_graph->parameters().size(); ++i) {
    const auto &param = func_graph->parameters()[i];
    auto param_node = param->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      new_paras.push_back(param_node);
      continue;
    }

    // Handle the Parameter from input arguments.
    if (i < arguments.size()) {
      auto param_value = dyn_cast<tensor::MetaTensor>(arguments[i]);
      if (param_value != nullptr && param_value->is_parameter()) {
        param_node->set_default_param(param_value);
      }
    }

    AbstractBasePtr param_abs = param_node->abstract();
    MS_EXCEPTION_IF_NULL(param_abs);
    if ((param_abs->BuildValue() == kValueAny && !ContainsAbstractFunction(param_abs)) ||
        EnableGradForScalar(param_abs) || EnableSequenceBroaden(param_abs)) {
      new_paras.push_back(param_node);
    } else {
      MS_LOG(INFO) << "Remove the " << i << "th parameter, since it's passed a constant argument.";
    }
  }
  func_graph->set_parameters(new_paras);
}

// Check whether ops such as ScalarAdd/ScalarSub exist, which will back off to the CPU.
bool IsNeedBackoffGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return(), SuccDeeperSimple);
  return std::any_of(node_list.begin(), node_list.end(),
                     [](const AnfNodePtr &node) { return common::AnfAlgo::IsNodeMutableScalar(node); });
}

// Disable mindRT in the heterogeneous + dynamic-shape scenario.
void DisableMindRT(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (!context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    return;
  }
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->cache_enable()) {
    return;
  }
#endif
}

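// Compile the whole func graph into a MindRT actor set and record the resulting actor info in the
// resource as both the output and the actor-info result.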
void TaskEmitActionForMindRT(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  // Get the mindRT backend.
  auto bc_ptr = resource->GetBackend();
  // In the pyexecute kernel, the input data is stored as user data, which is a Python object; this converter
  // is used to convert the user data to a device ptr in the device address.
  compile::set_pydata_converter([](const py::object &obj, ValuePtr *value) { return parse::ConvertData(obj, value); });
  auto mindrt_bc_ptr = std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr);
  MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);
  MS_EXCEPTION_IF_NULL(resource->func_graph());
  auto actor_info = mindrt_bc_ptr->CompileGraphs(resource->func_graph());
  resource->SetResult(kOutput, actor_info);
  resource->SetResult(kActorInfo, actor_info);
}

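// Wrap the compiled actor set in a VM evaluation closure: running the graph becomes a plain function
// call that takes a VectorRef of arguments and returns the first output.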
void ExecuteActionForMindRT(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  const auto actor_info = resource->GetResult(kOutput).cast<compile::ActorInfo>();
  // Get the mindRT backend.
  auto bc_ptr = resource->GetBackend();
  auto mindrt_bc_ptr = (std::dynamic_pointer_cast<compile::MindRTBackend>(bc_ptr)).get();
  MS_EXCEPTION_IF_NULL(mindrt_bc_ptr);

  // Construct the graph run function ptr.
  compile::VmEvalFuncPtr run =
    std::make_shared<compile::VmEvalFunc>([mindrt_bc_ptr, actor_info](const VectorRef &args) -> BaseRef {
      MS_LOG(DEBUG) << "Execute args size " << args.size();
      VectorRef outputs;
      mindrt_bc_ptr->RunGraph(actor_info, args, &outputs);
      MS_LOG(DEBUG) << "out size " << outputs.size();
      if (outputs.empty()) {
        return VectorRef();
      } else {
        return outputs[0];
      }
    });
  resource->SetResult(kOutput, run);
}

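// Build a wrapper graph whose single CNode calls `func`, with one fresh parameter per argument abstract,
// so a bare callable value can be fed through abstract analysis like an ordinary graph.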
FuncGraphPtr ConstructGraphForEval(const ValuePtr &func, const abstract::AbstractBasePtrList &args_abs) {
  auto func_abs = func->ToAbstract();
  if (!func_abs->isa<abstract::AbstractFunction>()) {
    MS_LOG(EXCEPTION) << "The value: " << func->ToString() << " is not a callable object.";
  }
  // Construct a function graph.
  auto infer_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> inputs = {std::make_shared<ValueNode>(func)};
  std::transform(args_abs.begin(), args_abs.end(), std::back_inserter(inputs),
                 [infer_graph](const AbstractBasePtr &) -> AnfNodePtr { return infer_graph->add_parameter(); });
  auto infer_node = infer_graph->NewCNode(inputs);
  infer_graph->set_return(infer_node);
  return infer_graph;
}
}  // namespace
using CompileGraphs = compile::CompileGraphs;
using abstract::AnalysisResult;
using mindspore::abstract::AnalysisContextPtr;

// Whether this process is in a MindSpore cluster.
static bool is_cluster_initialized = false;

bool IsDynamicShapeGraph(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return(), SuccDeeperSimple);
  return std::any_of(node_list.begin(), node_list.end(), [](const AnfNodePtr &node) {
    if (common::AnfAlgo::IsCallNode(node)) {
      return false;
    }
    return common::AnfAlgo::IsDynamicShape(node);
  });
}

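// Run abstract interpretation (type/shape inference) on `func_graph` with the given argument abstracts.
// Stale abstracts from a previous analysis are cleared first; nodes loaded from MindIR whose primitives
// have no front-end evaluator keep the inferred values carried by the MindIR file.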
abstract::AnalysisResult AbstractAnalyze(const abstract::AnalysisEnginePtr &engine, const FuncGraphPtr &func_graph,
                                         const abstract::AbstractBasePtrList &args_abs, bool is_load_resource,
                                         bool clear) {
  MS_LOG(DEBUG) << "AbstractAnalyze start";
  py::gil_scoped_acquire gil;
  MS_EXCEPTION_IF_NULL(engine);
  if (clear || is_load_resource) {
    auto manager = engine->func_graph_manager();
    MS_EXCEPTION_IF_NULL(manager);
    engine->Clear();
    static const auto enable_eliminate_unused_element = (common::GetCompileConfig("ENABLE_DDE") != "0");
    for (auto &node : manager->all_nodes()) {
      MS_EXCEPTION_IF_NULL(node);
      // Handle the previously inferred value for a CNode if it is loaded from MindIR.
      // If the primitive is not defined in the front end, keep the inferred value loaded from MindIR.
      if (is_load_resource) {
        auto primitive = GetCNodePrimitive(node);
        if (primitive != nullptr) {
          auto is_load = primitive->GetAttr("is_load");
          if (abstract::GetPrimEvaluator(primitive, engine) == nullptr && is_load != nullptr &&
              GetValue<bool>(is_load)) {
            MS_LOG(INFO) << "The primitive is not defined in the front end. Primitive: " << primitive->ToString();
            continue;
          }
        }
        if (!clear && node->isa<Parameter>()) {
          continue;
        }
      }

      const AbstractBasePtr &prev_inferred = node->abstract();
      // Keep the previously inferred value for a ValueNode if the inferred value is not AbstractFunction.
      if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
        // Reset tuple/list abstract use flags.
        if (enable_eliminate_unused_element && prev_inferred != nullptr &&
            prev_inferred->isa<abstract::AbstractSequence>()) {
          SetSequenceNodeElementsUseFlags(node, nullptr);
        }
        node->set_abstract(nullptr);
        MS_LOG(DEBUG) << "Abstract of node " << node->DebugString() << " is set to nullptr";
      }
    }
  }
  auto res = engine->Run(func_graph, args_abs);
  MS_LOG(INFO) << "function call depth: " << abstract::FunctionCallDepth()
               << ", simulate call depth: " << abstract::StackFrameDepth();
  MS_LOG(DEBUG) << "AbstractAnalyze end";
  return res;
}

abstract::AnalysisResult AbstractAnalyze(const ValuePtr &func, const abstract::AbstractBasePtrList &args_abs,
                                         bool clear) {
  auto infer_graph = func->isa<FuncGraph>() ? func->cast<FuncGraphPtr>() : ConstructGraphForEval(func, args_abs);
  auto manager = Manage(infer_graph, true);
  auto engine = std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager);
  return AbstractAnalyze(engine, infer_graph, args_abs, false, clear);
}

abstract::AnalysisResult AbstractAnalyzeWithResourceClean(const ValuePtr &func,
                                                          const abstract::AbstractBasePtrList &args_abs) {
  auto infer_graph = func->isa<FuncGraph>() ? func->cast<FuncGraphPtr>() : ConstructGraphForEval(func, args_abs);

  ResourcePtr resource = std::make_shared<Resource>();
  resource->set_func_graph(infer_graph);

  auto engine = resource->engine();
  auto res = AbstractAnalyze(engine, infer_graph, args_abs, false, true);

  GraphExecutorPy::GetInstance()->CleanCompileRes(resource);
  return res;
}

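// Specialize the analyzed graph under the given analysis context, then keep only the graphs reachable
// from the specialized root.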
FuncGraphPtr ProgramSpecialize(const abstract::AnalysisEnginePtr &engine, const FuncGraphPtr &func_graph,
                               const abstract::AnalysisContextPtr &context) {
  MS_EXCEPTION_IF_NULL(engine);
  MS_LOG(DEBUG) << "ProgramSpecialize start";
  abstract::ProgramSpecializer specializer(engine);
  FuncGraphPtr result = specializer.Run(func_graph, context);
  auto manager = engine->func_graph_manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->KeepRoots({result});
  specializer.SpecializeCNodeInput0FuncGraph();
  MS_LOG(DEBUG) << "ProgramSpecialize end";
  return result;
}

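// Renormalize re-runs abstract analysis followed by program specialization, typically after a pass has
// rewritten the graph and invalidated the previously inferred abstracts.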
FuncGraphPtr Renormalize(const ResourcePtr &resource, const FuncGraphPtr &func_graph,
                         const abstract::AbstractBasePtrList &args_abs) {
  MS_EXCEPTION_IF_NULL(resource);
  MS_LOG(DEBUG) << "Renormalize start";
  auto engine = resource->engine();

  abstract::AnalysisResult result;
  {
    MsProfileStatGuard stat_guard("renormalize.infer");
    result = AbstractAnalyze(engine, func_graph, args_abs, resource->is_load(), true);
  }
  FuncGraphPtr res;
  {
    MsProfileStatGuard stat_guard("renormalize.specialize");
    res = ProgramSpecialize(engine, func_graph, result.context);
    resource->set_func_graph(res);
  }

  MS_LOG(DEBUG) << "Renormalize end";
  return res;
}

FuncGraphPtr Renormalize(const ValuePtr &func, const abstract::AbstractBasePtrList &args_abs) {
  auto func_abs = func->ToAbstract();
  if (!func_abs->isa<abstract::AbstractFunction>()) {
    MS_LOG(EXCEPTION) << "The value: " << func->ToString() << " is not a callable object.";
  }
  auto func_graph = ConstructGraphForEval(func, args_abs);
  auto manager = Manage(func_graph, true);
  auto engine = std::make_shared<abstract::AnalysisEngine>(abstract::GetPrimEvaluatorConstructors(), manager);

  abstract::AnalysisResult result;
  {
    MsProfileStatGuard stat_guard("renormalize.infer");
    result = AbstractAnalyze(engine, func_graph, args_abs, false);
  }
  FuncGraphPtr res;
  {
    MsProfileStatGuard stat_guard("renormalize.specialize");
    res = ProgramSpecialize(engine, func_graph, result.context);
  }

  return res;
}

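// Mark the resource as loaded-from-MindIR if any managed graph carries the "is_load" attribute.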
void SetMindIRLoadFlag(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto all_graphs = manager->func_graphs();
  for (auto &graph : all_graphs) {
    MS_EXCEPTION_IF_NULL(graph);
    if (graph->has_attr("is_load")) {
      resource->set_is_load(true);
      return;
    }
  }
}

namespace {
// Get entry function/class.method name.
std::string GetFunctionName(const py::object &input) {
  // Get Cell.construct() or @jit function name.
  std::string function_name;
  if (py::hasattr(input, parse::PYTHON_PARSE_METHOD)) {
    // The class type string format is like: <class 'x.x.xxx'>
    std::string class_type_name = py::cast<std::string>(py::str(input.get_type()));
    constexpr auto class_type_prefix_len = 8;  // <class '
    constexpr auto class_type_suffix_len = 2;  // '>
    const auto class_type_len = class_type_name.length();
    // Exclude class prefix and suffix.
    auto class_name =
      class_type_name.substr(class_type_prefix_len, class_type_len - class_type_prefix_len - class_type_suffix_len);
    auto method_name = py::cast<std::string>(input.attr(parse::PYTHON_PARSE_METHOD));
    function_name = class_name + '.' + method_name;
  } else if (py::hasattr(input, "__jit_function__") && py::hasattr(input, "__name__")) {
    // Get @jit decorated function name.
    auto jit_name = py::cast<std::string>(input.attr("__name__"));
    function_name = jit_name;
  } else {
    MS_EXCEPTION(NotSupportError) << "Entry Python object for JIT is invalid.\ninput: " << py::str(input);
  }
  MS_LOG(DEBUG) << "function_name: " << function_name;
  return function_name;
}

// Update top graph name.
void UpdateTopGraphDebugInfo(const FuncGraphPtr &func_graph, const py::object &input) {
  auto function_name = GetFunctionName(input);
  // Normalize the name.
  std::replace(function_name.begin(), function_name.end(), '.', '_');
  std::replace(function_name.begin(), function_name.end(), '<', '_');
  std::replace(function_name.begin(), function_name.end(), '>', '_');

  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(func_graph->debug_info());
  func_graph->debug_info()->set_name(function_name);
}

struct FuncArgSpec {
  AnfNodePtrList args_;
  ParameterPtr varargs_{nullptr};
  AnfNodePtrList kwonlyargs_;
  ParameterPtr varkw_{nullptr};
};

void MakeDefaultValue(const py::dict &defaults, const std::string &arg_name,
                      std::vector<std::string> *namelist_for_default_value, std::vector<AnfNodePtr> *default_values) {
  (void)namelist_for_default_value->emplace_back(arg_name);
  if (defaults.contains(arg_name)) {
    AnfNodePtr arg_node = NewValueNode(parse::data_converter::PyDataToValue(defaults[py::str(arg_name)]));
    (void)default_values->emplace_back(arg_node);
  } else {
    (void)default_values->emplace_back(NewValueNode(kNull));
  }
}

bool CheckIgnoreSelfParam(const py::object &input) {
  auto input_type = parse::data_converter::GetObjType(input);
  if (input_type == parse::ResolveType::RESOLVE_TYPE_CLASS_INSTANCE) {
    return true;
  }
  if (input_type == parse::ResolveType::RESOLVE_TYPE_METHOD) {
    py::object method_object = python_adapter::GetPyObjAttr(input, parse::PYTHON_GET_METHOD_SELF_CLASS);
    if (!py::isinstance<py::none>(method_object)) {
      return true;
    }
  }
  return false;
}

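// Mirror the Python signature of the entry onto the top graph: add one graph parameter per positional
// argument, *args, keyword-only argument and **kwargs, and record their default values.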
FuncArgSpec GetFuncArgSpec(const FuncGraphPtr &func_graph, const py::object &input) {
  auto func = input;
  if (py::hasattr(input, parse::PYTHON_PARSE_METHOD)) {
    auto func_name = py::cast<std::string>(input.attr(parse::PYTHON_PARSE_METHOD));
    func = input.attr(func_name.c_str());
  }
  py::tuple obj_tuple =
    python_adapter::CallPyFn(parse::PYTHON_MOD_PARSE_MODULE, "get_arg_spec_and_default_values", func);
  auto full_arg_spec = obj_tuple[0];
  py::dict defaults = obj_tuple[1];
  std::vector<std::string> namelist_for_default_value;
  std::vector<AnfNodePtr> default_values;
  FuncArgSpec arg_spec;
  bool ignore_self_param = CheckIgnoreSelfParam(input);
  if (py::hasattr(full_arg_spec, "args")) {
    for (const auto &arg : full_arg_spec.attr("args")) {
      auto arg_name = py::cast<std::string>(arg);
      if (arg_name == "self" && ignore_self_param) {
        continue;
      }
      auto para = func_graph->add_parameter();
      para->set_is_top_graph_param(true);
      para->set_name(arg_name);
      (void)arg_spec.args_.emplace_back(para);
      MakeDefaultValue(defaults, arg_name, &namelist_for_default_value, &default_values);
    }
  }

  if (py::hasattr(full_arg_spec, "varargs")) {
    auto varargs = full_arg_spec.attr("varargs");
    if (!py::isinstance<py::none>(varargs)) {
      arg_spec.varargs_ = func_graph->add_parameter();
      arg_spec.varargs_->set_is_top_graph_param(true);
      auto arg_name = py::cast<std::string>(varargs);
      arg_spec.varargs_->set_name(arg_name);
      func_graph->set_has_vararg(true);
      MakeDefaultValue(defaults, arg_name, &namelist_for_default_value, &default_values);
    }
  }

  if (py::hasattr(full_arg_spec, "kwonlyargs")) {
    for (const auto &arg : full_arg_spec.attr("kwonlyargs")) {
      auto para = func_graph->add_parameter();
      para->set_is_top_graph_param(true);
      auto arg_name = py::cast<std::string>(arg);
      para->set_name(arg_name);
      (void)arg_spec.kwonlyargs_.emplace_back(para);
      MakeDefaultValue(defaults, arg_name, &namelist_for_default_value, &default_values);
    }
    func_graph->set_kwonlyargs_count(SizeToInt(arg_spec.kwonlyargs_.size()));
  }

  if (py::hasattr(full_arg_spec, "varkw")) {
    auto varkw = full_arg_spec.attr("varkw");
    if (!py::isinstance<py::none>(varkw)) {
      arg_spec.varkw_ = func_graph->add_parameter();
      arg_spec.varkw_->set_is_top_graph_param(true);
      auto arg_name = py::cast<std::string>(varkw);
      arg_spec.varkw_->set_name(arg_name);
      func_graph->set_has_kwarg(true);
      MakeDefaultValue(defaults, arg_name, &namelist_for_default_value, &default_values);
    }
  }
  func_graph->SetDefaultValues(namelist_for_default_value, default_values);
  return arg_spec;
}

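// Build the body of the fake top graph: create a Resolve node for the user entry, then call it with the
// top graph's parameters, going through UnpackCall when *args/**kwargs/keyword-only arguments exist.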
void BuildTopGraph(const FuncGraphPtr &func_graph, const py::object &input,
                   const abstract::AbstractBasePtrList &args_abs) {
  // Make Resolve for user top graph 'input'.
  auto function_name = GetFunctionName(input);
  parse::NameSpacePtr name_space =
    std::make_shared<parse::NameSpace>(parse::RESOLVE_NAMESPACE_NAME_ENTRY, py::str(function_name), input);
  parse::SymbolPtr symbol = std::make_shared<parse::Symbol>(function_name);
  MS_LOG(DEBUG) << "name_space: " << name_space->ToString() << ", symbol: " << symbol->ToString();
  ValueNodePtr module_node = NewValueNode(name_space);
  ValueNodePtr symbol_node = NewValueNode(symbol);

  bool contains_value_any = false;
  ValuePtrList args_value_list;
  (void)std::transform(args_abs.cbegin(), args_abs.cend(), std::back_inserter(args_value_list),
                       [&contains_value_any](const AbstractBasePtr &abs) {
                         auto res = abs->BuildValue();
                         if (res->isa<ValueAny>()) {
                           contains_value_any = true;
                         }
                         return res;
                       });
  CNodePtr resolve_node;
  if (contains_value_any) {
    resolve_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node});
  } else {
    ValueNodePtr args_node = NewValueNode<ValuePtrList>(args_value_list);
    resolve_node = func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node, args_node});
  }

  auto arg_spec = GetFuncArgSpec(func_graph, input);
  bool need_unpack = false;
  if (func_graph->has_vararg() || func_graph->has_kwarg() || func_graph->kwonlyargs_count() > 0) {
    need_unpack = true;
  }
  // Call user top graph in top graph.
  AnfNodePtrList inputs;
  if (!need_unpack) {
    (void)inputs.emplace_back(resolve_node);
    std::copy(func_graph->parameters().cbegin(), func_graph->parameters().cend(), std::back_inserter(inputs));
  } else {
    (void)inputs.emplace_back(NewValueNode(std::make_shared<prim::UnpackCall>(parse::NAMED_METAGRAPH_UNPACKCALL)));
    (void)inputs.emplace_back(resolve_node);
    if (!arg_spec.args_.empty()) {
      AnfNodePtrList args_inputs = {NewValueNode(prim::kPrimMakeTuple)};
      std::copy(arg_spec.args_.cbegin(), arg_spec.args_.cend(), std::back_inserter(args_inputs));
      (void)inputs.emplace_back(func_graph->NewCNodeInOrder(args_inputs));
    }
    if (arg_spec.varargs_ != nullptr) {
      (void)inputs.emplace_back(arg_spec.varargs_);
    }
    if (arg_spec.varkw_ != nullptr) {
      (void)inputs.emplace_back(arg_spec.varkw_);
    }
    if (!arg_spec.kwonlyargs_.empty()) {
      AnfNodePtrList key_inputs = {NewValueNode(prim::kPrimMakeTuple)};
      AnfNodePtrList value_inputs = {NewValueNode(prim::kPrimMakeTuple)};
      for (const auto &kwonlyarg : arg_spec.kwonlyargs_) {
        (void)key_inputs.emplace_back(NewValueNode(kwonlyarg->cast<ParameterPtr>()->name()));
        (void)value_inputs.emplace_back(kwonlyarg);
      }
      auto make_dict =
        func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimMakeDict), func_graph->NewCNodeInOrder(key_inputs),
                                     func_graph->NewCNodeInOrder(value_inputs)});
      (void)inputs.emplace_back(make_dict);
    }
  }
  auto output = func_graph->NewCNodeInOrder(inputs);
  constexpr auto recursive_level = 2;
  MS_LOG(DEBUG) << "output: " << output->DebugString(recursive_level);
  func_graph->set_output(output);
}
}  // namespace

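// Bootstrap the compile pipeline: validate the Python entry (Cell.construct or an @jit function), set up
// the Python environment, and install a fake top graph that resolves and calls the user entry.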
bool BootstrapAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  TraceManager::OpenParserDebugInfoFlag();
  if (!resource->source_input()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Bootstrap error";
  }
  py::object input = resource->source_input();
  parse::Parser::InitParserEnvironment(input);
  parse::Parser::EnableDeferResolve(false);
  py::module path = py::module::import("os.path");
  auto dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();
  python_adapter::set_python_env_flag(true);
  python_adapter::SetPythonPath(dir);

  // Create fake top graph firstly.
  auto top_graph = std::make_shared<FuncGraph>();
  MS_EXCEPTION_IF_NULL(top_graph);
  auto is_top_graph = (py::hasattr(input, parse::PYTHON_PARSE_METHOD) || py::hasattr(input, "__jit_function__"));
  if (!is_top_graph) {
    MS_EXCEPTION(NotSupportError) << "Not supported Python object for JIT entry.\ninput: " << py::str(input);
  }
  UpdateTopGraphDebugInfo(top_graph, input);
  // Call the user top graph with its arguments.
  BuildTopGraph(top_graph, input, resource->args_abs());
  // Set the top graph.
  parse::Parser::UpdateTopFuncGraph(top_graph);
  resource->set_func_graph(top_graph);
  FuncGraphManagerPtr manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(top_graph);
  return true;
}

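// Parse the Python entry into a func graph through the data converter and register it as the top graph.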
bool ParseAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  TraceManager::OpenParserDebugInfoFlag();
  if (!resource->source_input()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Parse error";
  }

  py::object input = resource->source_input();
  parse::Parser::InitParserEnvironment(input);
  parse::Parser::EnableDeferResolve(false);
  py::module path = py::module::import("os.path");
  auto dir = path.attr("dirname")(py::globals()["__file__"]).cast<std::string>();

  python_adapter::set_python_env_flag(true);
  python_adapter::SetPythonPath(dir);

  ValuePtrList args_value_list;
  (void)std::transform(resource->args_abs().begin(), resource->args_abs().end(), std::back_inserter(args_value_list),
                       [](const AbstractBasePtr &abs) { return abs->BuildValue(); });
  parse::DataConverter data_converter(args_value_list, true);
  auto converted_ret = data_converter.ConvertData(input);
  if (converted_ret == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Attribute convert error with type: " << std::string(py::str(input));
  }

  auto top_graph = converted_ret->cast<FuncGraphPtr>();
  if (top_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not a function or cell.";
  }
  if (py::hasattr(input, parse::PYTHON_PARSE_METHOD) || py::hasattr(input, "__jit_function__")) {
    (void)std::for_each(top_graph->parameters().begin(), top_graph->parameters().end(),
                        [](const AnfNodePtr &param) { param->cast<ParameterPtr>()->set_is_top_graph_param(true); });
  }
  parse::Parser::UpdateTopFuncGraph(top_graph);
  resource->set_func_graph(top_graph);
  FuncGraphManagerPtr manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  manager->AddFuncGraph(top_graph);

  parse::Parser::EnableDeferResolve(true);
  return true;
}

// The graphs in obj_map share the same structure, so they can be optimized into one graph.
// This step performs that optimization: graph1(x){xxx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)} ->
// graph1(x){base_graph(x, fv1, fv2)}, graph2(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx}
// All of obj_map's graphs share base_graph.
bool CombineLikeGraphs(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto &obj_map = parse::data_converter::GetObjGraphs();
  for (auto it = obj_map.rbegin(); it != obj_map.rend(); ++it) {
    if (it->first.find("lazy_inline") != it->first.npos) {
      continue;
    }
    auto &graphs = it->second;
    MS_LOG(DEBUG) << "Start combine like graph:" << it->first << ", size:" << graphs.size();
    auto fg = graphs[0];
    FuncGraphVector func_graphs = {fg};
    Cloner cloner(func_graphs, false, false, true, std::make_shared<TraceCopy>(),
                  std::make_shared<TraceCombileLikeGraphs>());
    cloner.Run();
    auto cloned_fg_iter = cloner.cloned_func_graphs().find(fg);
    if (cloned_fg_iter == cloner.cloned_func_graphs().end()) {
      MS_LOG(INTERNAL_EXCEPTION) << "Clone func graph failed! " << fg->ToString();
    }
    auto base_graph = cloned_fg_iter->second;
    MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString();

    if (fg->parameter_obj_nodes().empty() || graphs.size() <= 1 || fg->has_flag(FUNC_GRAPH_OUTPUT_NO_RECOMPUTE) ||
        fg->stage() != -1) {
      continue;
    }
    auto &cloned_nodes = cloner.cloned_nodes();
    for (auto &fv : fg->parameter_obj_nodes()) {
      TraceGuard guard(std::make_shared<TraceCombileLikeGraphs>(fg->output()->debug_info()));
      auto param = base_graph->add_parameter();
      MS_EXCEPTION_IF_NULL(resource->manager());
      auto &node_users = resource->manager()->node_users()[fv];
      for (auto &n : node_users) {
        // If the user is not in this graph, no need to change.
        auto iter = cloned_nodes.find(n.first);
        if (iter == cloned_nodes.end()) {
          continue;
        }
        auto repl_n = iter->second->cast<CNodePtr>();
        MS_EXCEPTION_IF_NULL(repl_n);
        repl_n->set_input(IntToSize(n.second), param);
      }
    }
    MS_LOG(DEBUG) << "Fg0 parameter_obj_nodes size :" << fg->parameter_obj_nodes().size();

    for (auto &g : graphs) {
      TraceGuard guard(std::make_shared<TraceCopy>(fg->output()->debug_info()));
      auto &fvs = g->parameter_obj_nodes();
      std::vector<AnfNodePtr> new_node_inputs;
      new_node_inputs.push_back(NewValueNode(base_graph));
      for (auto &p : g->parameters()) {
        AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p);
        new_node_inputs.push_back(para_after_cast);
      }
      (void)new_node_inputs.insert(new_node_inputs.end(), fvs.cbegin(), fvs.cend());
      AnfNodePtr out = g->NewCNodeBefore(g->get_return(), new_node_inputs);
      g->set_output(out);
      const int recursive_level = 4;
      MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(recursive_level);
    }
    MS_LOG(DEBUG) << "End combine graph:" << it->first;
  }
  return true;
}

namespace {
// Get all the trainable parameters of the reusable cell.
void GenerateTopGraphParams(const FuncGraphPtr &fg, std::vector<AnfNodePtr> *params,
                            const FuncGraphPtr &top_func_graph) {
  MS_LOG(DEBUG) << "enter GenerateTopGraphParams: " << fg->ToString();
  auto obj_value = fg->python_obj();
  MS_EXCEPTION_IF_NULL(obj_value);
  auto wrapper = dyn_cast_ptr<parse::PyObjectWrapper>(obj_value);
  MS_EXCEPTION_IF_NULL(wrapper);
  auto obj = wrapper->obj();
  auto trainable_parameters = py::getattr(obj, "parameters_and_names", py::none())();
  for (auto tr : trainable_parameters) {
    auto item = py::cast<py::tuple>(tr);
    auto value = item[1];
    auto par_name = item[0].cast<std::string>();
    auto parameter_name = py::getattr(value, "name", py::str(par_name)).cast<std::string>();
    auto exist_fv = top_func_graph->GetParameterByName(parameter_name);
    if (exist_fv) {
      params->push_back(exist_fv);
      MS_LOG(DEBUG) << "exist: " << parameter_name;
    } else {
      auto fv = top_func_graph->AddFvParameter(parameter_name, parse::GetParameterValue(value));
      auto context = parallel::ParallelContext::GetInstance();
      if (context != nullptr && fv->has_default()) {
        auto fv_abs = pipeline::GetDefaultValueAbstract(fv);
        context->ParallelParameterContextRestoreShape(top_func_graph, fv, fv_abs);
        fv->set_abstract(fv_abs);
      }
      MS_LOG(DEBUG) << "New: " << parameter_name;
      params->push_back(fv);
    }
  }
  MS_LOG(DEBUG) << "finish GenerateTopGraphParams: " << fg->ToString();
}

void UpdateCellFuncGraph(const FuncGraphPtr &func_graph, const FuncGraphPtr &reusing_graph,
                         const FuncGraphPtr &top_func_graph) {
  std::vector<AnfNodePtr> new_node_inputs;
  new_node_inputs.push_back(NewValueNode(reusing_graph));
  std::vector<AnfNodePtr> fvs;
  GenerateTopGraphParams(func_graph, &fvs, top_func_graph);
  (void)new_node_inputs.insert(new_node_inputs.end(), fvs.rbegin(), fvs.rend());
  auto params = func_graph->parameters();
  (void)new_node_inputs.insert(new_node_inputs.end(), params.begin(), params.end());
  AnfNodePtr out = func_graph->NewCNodeInOrder(new_node_inputs);
  out->set_abstract(func_graph->output()->abstract());
  func_graph->set_output(out);
}

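// Clone the lazy-inline cell graph into a reusable graph: lift its free-variable parameters to the front
// of the parameter list, mark it as no-inline/cell-reuse, and rewrite every call site to call the
// reusable graph instead.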
void GeneralizeReusingGraph(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_func_graph) {
  FuncGraphPtr fg = func_graph;
  FuncGraphVector func_graphs = {fg};
  Cloner cloner(func_graphs, false, false, true, std::make_shared<TraceCopy>(), std::make_shared<TraceGraphReusing>());
  cloner.Run();
  auto cloned_fg_iter = cloner.cloned_func_graphs().find(fg);
  if (cloned_fg_iter == cloner.cloned_func_graphs().end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "Clone func graph failed! " << fg->ToString();
  }
  auto reusing_graph = cloned_fg_iter->second;
  auto &cloned_nodes = cloner.cloned_nodes();
  auto manager = fg->manager();
  std::vector<AnfNodePtr> fv_params;
  GenerateTopGraphParams(fg, &fv_params, top_func_graph);
  for (auto &fv : fv_params) {
    auto param = reusing_graph->InsertFrontParameter();
    const auto &top_param = fv->cast<ParameterPtr>();
    std::string name = "CR_" + top_param->name();
    param->debug_info()->set_name(name);
    param->set_name(name);
    param->set_abstract(top_param->abstract());
    auto &node_users = manager->node_users()[fv];
    for (auto &n : node_users) {
      auto iter = cloned_nodes.find(n.first);
      if (iter == cloned_nodes.end()) {
        continue;
      }
      auto repl_n = iter->second->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(repl_n);
      repl_n->set_input(IntToSize(n.second), param);
    }
  }

  if (func_graph->has_attr(FUNC_GRAPH_FLAG_NO_INLINE)) {
    reusing_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, func_graph->has_flag(FUNC_GRAPH_FLAG_NO_INLINE));
  } else {
    reusing_graph->set_flag(FUNC_GRAPH_FLAG_NO_INLINE, true);
    reusing_graph->set_flag(FUNC_GRAPH_FLAG_CELL_REUSE, true);
  }

  // Update call nodes.
  auto no_inline_flag = reusing_graph->has_flag(FUNC_GRAPH_FLAG_NO_INLINE);
  auto cnodes_index = fg->func_graph_cnodes_index();
  for (auto &cnode_index : cnodes_index) {
    MS_EXCEPTION_IF_NULL(cnode_index.first);
    auto old_cnode = cnode_index.first->first->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(old_cnode);
    auto cell_func_graph = old_cnode->func_graph();
    MS_EXCEPTION_IF_NULL(cell_func_graph);
    UpdateCellFuncGraph(cell_func_graph, reusing_graph, top_func_graph);

    // Optimize FuncGraph::scope() performance.
    cell_func_graph->set_flag(FUNC_GRAPH_FLAG_NO_CHILD_GRAPH, no_inline_flag);
  }
}

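// Propagate the mixed precision flag (fp16/fp32/bf16) of the first flagged subgraph to all graphs that
// the flagged subgraph uses, so nested calls inherit the precision setting.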
void SetCalledSubGraphMixedPrecisionFlag(const FuncGraphPtr &func_graph) {
  FuncGraphPtr fp16_mixed_precision_fg;
  FuncGraphPtr fp32_mixed_precision_fg;
  FuncGraphPtr bf16_mixed_precision_fg;
  // Find the first subgraph which has a mixed precision flag.
  for (auto &item : func_graph->func_graphs_used()) {
    if (item.first->has_flag(GRAPH_FLAG_MIX_PRECISION_FP16)) {
      fp16_mixed_precision_fg = item.first;
    }
    if (item.first->has_flag(GRAPH_FLAG_MIX_PRECISION_FP32)) {
      fp32_mixed_precision_fg = item.first;
    }
    if (item.first->has_flag(GRAPH_FLAG_MIX_PRECISION_BF16)) {
      bf16_mixed_precision_fg = item.first;
    }
    if ((fp32_mixed_precision_fg != nullptr) || (fp16_mixed_precision_fg != nullptr) ||
        (bf16_mixed_precision_fg != nullptr)) {
      break;
    }
  }

  // Add the mixed precision flag to every subgraph called by the flagged subgraph.
  if (fp16_mixed_precision_fg != nullptr) {
    for (auto sub_fg : fp16_mixed_precision_fg->func_graphs_used_total()) {
      sub_fg->set_flag(GRAPH_FLAG_MIX_PRECISION_FP16, true);
    }
  }
  if (fp32_mixed_precision_fg != nullptr) {
    for (auto sub_fg : fp32_mixed_precision_fg->func_graphs_used_total()) {
      sub_fg->set_flag(GRAPH_FLAG_MIX_PRECISION_FP32, true);
    }
  }
  if (bf16_mixed_precision_fg != nullptr) {
    for (auto sub_fg : bf16_mixed_precision_fg->func_graphs_used_total()) {
      sub_fg->set_flag(GRAPH_FLAG_MIX_PRECISION_BF16, true);
    }
  }
}
}  // namespace

// Make the reusable cell into a reusable function graph.
bool GraphReusingAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  bool cell_reused = false;
  auto func_graph = resource->func_graph();
  std::multimap<int, FuncGraphPtr> order_fgs;
  for (auto &fg : func_graph->func_graphs_used_total()) {
    auto order_value = fg->get_attr(FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER);
    if (order_value == nullptr) {
      continue;
    }
    fg->erase_flag(FUNC_GRAPH_FLAG_CELL_LAZY_INLINE_ORDER);
    order_fgs.insert(std::make_pair(GetValue<int>(order_value), fg));
  }
  for (auto it = order_fgs.rbegin(); it != order_fgs.rend(); ++it) {
    MS_LOG(INFO) << "Lazy_inline graph: " << it->second->ToString() << ", order: " << it->first;
    GeneralizeReusingGraph(it->second, func_graph);
    cell_reused = true;
  }
  if (!cell_reused) {
    return true;
  }

  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  const bool enable_ge = context->backend_policy() == "ge";
  const bool force_no_inline = common::IsDisableRuntimeConfig(common::kRuntimeInline);
  context->SetCellReuseLevel(CellReuseLevel::kNoCellReuse);

  MS_LOG(INFO) << "Cell reuse(@lazy_inline) actually takes effect.";
  auto cell_reuse_level =
    (enable_ge && !context->IsKByKExecutorMode()) ? CellReuseLevel::kNoInline : CellReuseLevel::kLazyInline;
  if (force_no_inline) {
    cell_reuse_level = CellReuseLevel::kNoInline;
  }
  context->SetCellReuseLevel(cell_reuse_level);

  return true;
}

// Used for excluding the func graphs in VMap.
bool UsedByVmap(const FuncGraphPtr &func_graph) {
  const auto &cnodes_index = func_graph->func_graph_cnodes_index();
  if (cnodes_index.empty()) {
    return false;
  }
  const auto matcher = [&func_graph](const std::pair<const CNodeIndexPairPtr, int64_t> &cnode_index) {
    const auto &cnode = cnode_index.first->first;
    const auto &vmap_meta = GetCNodeValueWithoutDoSignature(cnode);
    if (vmap_meta != nullptr && vmap_meta->isa<prim::VmapOperation>()) {
      MS_LOG(DEBUG) << "Found VMap CNode: " << cnode->DebugString();
      return true;
    }
    // The func graph is used in MakeTuple or UnpackGraph.
    const auto user_matcher = [](const FuncGraphPtr &func_graph, const AnfNodePtr &cnode) {
      auto manager = func_graph->manager();
      MS_EXCEPTION_IF_NULL(manager);
      auto &users = manager->node_users()[cnode];
      for (const auto &user : users) {
        const auto &user_vmap_meta = GetCNodeValueWithoutDoSignature(user.first);
        if (user_vmap_meta != nullptr && user_vmap_meta->isa<prim::VmapOperation>()) {
          MS_LOG(DEBUG) << "Found VMap CNode: " << user.first->DebugString();
          return true;
        }
      }
      return false;
    };
    const auto unpack_graph_prim = GetCNodePrimitive(cnode);
    if (unpack_graph_prim != nullptr && unpack_graph_prim->isa<prim::UnpackGraphPrimitive>()) {
      return user_matcher(func_graph, cnode);
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimMakeTuple)) {
      return user_matcher(func_graph, cnode);
    }
    // Deal with F.vmap(fn, ...) in construct().
    // Do not check fn passed from nested func graph calls.
    if (cnode_index.first->second == 1) {
      const auto vmap_func = GetCNodeFuncGraph(cnode);
      if (vmap_func == nullptr) {
        return false;
      }
      auto first_param = vmap_func->parameters()[0];
      return user_matcher(func_graph, first_param);
    }
    return false;
  };
  return std::any_of(cnodes_index.cbegin(), cnodes_index.cend(), matcher);
}

bool PreCConvAction(const ResourcePtr &resource) {
  static const bool enable_pre_lift = (common::GetCompileConfig("PRE_LIFT") == "1");
  if (!enable_pre_lift) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(resource);
  MS_EXCEPTION_IF_NULL(resource->func_graph());
  FuncGraphPtr func_graph = resource->func_graph();
  FuncGraphPtr new_fg = LiftingClone(func_graph, false, UsedByVmap);
  resource->set_func_graph(new_fg);
  return GradPartialTransformPass(resource);
}

bool SymbolResolveAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->manager() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "SymbolResolve error, manager is null";
  }
  auto func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "SymbolResolve error, graph is null";
  }
  bool ret = parse::ResolveFuncGraph(func_graph, resource);
  // Remove unused nodes in the cnode order list,
  // and check isolated side-effect nodes.
  if (func_graph != nullptr) {
    func_graph->EraseUnusedNodeInOrder();
    for (auto fg : func_graph->func_graphs_used_total()) {
      if (fg != nullptr) {
        fg->EraseUnusedNodeInOrder();
      }
    }
  }
  return ret;
}

bool SetMixedPrecisionAction(const ResourcePtr &resource) {
  if (resource->manager() == nullptr) {
    MS_LOG(EXCEPTION) << "SetMixedPrecisionAction error, manager is null";
  }
  auto func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(EXCEPTION) << "SetMixedPrecisionAction error, graph is null";
  }
  SetCalledSubGraphMixedPrecisionFlag(func_graph);
  MS_LOG(DEBUG) << "Finish setting the mixed precision flag in subgraphs.";
  return true;
}

bool AutoMonadAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->manager() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Auto-Monad failed, manager is null";
  }
  auto func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Auto-Monad failed, graph is null";
  }
  (void)pipeline::AutoMonad(func_graph);
  return true;
}

bool OrderEnforceAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->manager() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Order-Enforce error, manager is null";
  }
  auto func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Order-Enforce error, graph is null";
  }
  pipeline::OrderEnforce(func_graph);
  return true;
}

// Get the abstract of the default value in the given parameter.
AbstractBasePtr GetDefaultValueAbstract(const ParameterPtr &param) {
  auto value = param->default_param();
  MS_EXCEPTION_IF_NULL(value);
  auto value_abs = value->ToAbstract();
  MS_EXCEPTION_IF_NULL(value_abs);
  if (value_abs->isa<abstract::AbstractMapTensor>()) {
    // Return AbstractMapTensor for a map parameter.
    return value_abs;
  }
  // Make an AbstractRefTensor for the tensor value.
  auto abs_tensor = value_abs->cast<abstract::AbstractTensorPtr>();
  MS_EXCEPTION_IF_NULL(abs_tensor);
  auto ref_key = std::make_shared<RefKey>(param->name());
  return std::make_shared<abstract::AbstractRefTensor>(abs_tensor, ref_key);
}

namespace {
abstract::AbstractBasePtrList GetArgsAbs(const ResourcePtr &resource) {
  FuncGraphPtr func_graph = resource->func_graph();
  abstract::AbstractBasePtrList args_abs = resource->args_abs();

  // Parallel checking.
  auto context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);

  // Handle the Parameter from FV inputs.
  for (const auto &param : func_graph->parameters()) {
    auto param_node = std::static_pointer_cast<Parameter>(param);
    MS_EXCEPTION_IF_NULL(param_node);
    if (param_node->has_default()) {
      auto param_abs = GetDefaultValueAbstract(param_node);
      context->ParallelParameterContextRestoreShape(func_graph, param_node, param_abs);
      (void)args_abs.emplace_back(param_abs);
    }
  }
  return args_abs;
}
}  // namespace

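// The type inference action: run abstract analysis on the top graph with the argument abstracts
// (including default-valued parameters), then specialize the result and install it as the new top graph.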
bool TypeInferenceAction(const ResourcePtr &resource) {
  EventMessage::PrintCompileStatusMessage("Start performing static analysis and type inference.");
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->func_graph() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "AbstractSpecialize error";
  }
  SetMindIRLoadFlag(resource);
  // Abstract analyze
  auto engine = resource->engine();
  MS_EXCEPTION_IF_NULL(engine);

  // Check isolated side-effect nodes.
  engine->set_check_side_effect(true);
  // Analyze
  (void)profiler::CollectHostInfo(kCompiler, kTypeInference, kAbstractAnalyze, 0, 0, 0);
  AnalysisResult result;
  {
    MsProfileStatGuard stat_guard("type_inference.infer");
    result = AbstractAnalyze(resource->engine(), resource->func_graph(), GetArgsAbs(resource), resource->is_load());
  }
  (void)profiler::CollectHostInfo(kCompiler, kTypeInference, kAbstractAnalyze, 0, 0, 1);
  // Specialize
  (void)profiler::CollectHostInfo(kCompiler, kTypeInference, kProgramSpecialize, 0, 0, 0);
  FuncGraphPtr new_fg;
  {
    MsProfileStatGuard stat_guard("type_inference.specialize");
    new_fg = ProgramSpecialize(resource->engine(), result.context->func_graph(), result.context);
  }
  (void)profiler::CollectHostInfo(kCompiler, kTypeInference, kProgramSpecialize, 0, 0, 1);
  // Update the top func graph with the specialized graph.
  parse::Parser::UpdateTopFuncGraph(new_fg);
  resource->set_func_graph(new_fg);
  engine->set_check_side_effect(false);

  // Remove unused nodes in the cnode order list; this prepares for auto-monad.
  if (new_fg) {
    new_fg->EraseUnusedNodeInOrder();
    for (auto fg : new_fg->func_graphs_used_total()) {
      if (fg) {
        fg->EraseUnusedNodeInOrder();
      }
    }
  }

  UpdateFuncGraphParameter(new_fg, resource->arguments());
  MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
  return true;
}

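// Run the given pass list over the resource, timing each pass and, when IR dump is enabled, dumping the
// func graph after every pass as opt_pass_<counter>_<pass name>.ir.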
bool OptimizeAction(const ResourcePtr &resource, const std::vector<PassItem> &passes) {
  MS_EXCEPTION_IF_NULL(resource);
  size_t counter = 0;
  for (auto &pass : passes) {
    ProcessStatus::GetInstance().RecordStart(pass.first);
    (void)profiler::CollectHostInfo(kCompiler, kOptimize, pass.first, 0, 0, 0);
    auto profile_context = MsProfile::GetProfile()->Step(pass.first);
    auto pass_func = [&pass, &resource, &counter]() {
      MS_LOG(DEBUG) << "Pass " << pass.first << " start ...";
      auto result = pass.second(resource);
      if (!result) {
        MS_LOG(INTERNAL_EXCEPTION) << "Pass running to end, failed in pass:" << pass.first;
      }
#ifdef ENABLE_DUMP_IR
      auto context = MsContext::GetInstance();
      MS_EXCEPTION_IF_NULL(context);
      if (context->CanDump(kIntroductory) && resource->func_graph() != nullptr) {
        auto fg_name = "opt_pass_" + std::to_string(counter) + "_" + pass.first;
        auto func_graph = resource->func_graph();
        MS_EXCEPTION_IF_NULL(func_graph);
        static const auto switch_order = (common::GetEnv("MS_DEV_SAVE_GRAPHS_SORT_MODE") == "1");
        if (switch_order) {
          ExportIR(fg_name + ".ir", func_graph);
        } else {
          DumpIR(fg_name + ".ir", func_graph);
        }
        if (context->CanDump(kFully)) {
          draw::Draw(fg_name + ".dot", func_graph);
        }
        MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
      }
#endif
      counter++;
      MS_LOG(DEBUG) << "Pass " << pass.first << " end.";
    };
    ProfileExecute(profile_context, pass_func);
    (void)profiler::CollectHostInfo(kCompiler, kOptimize, pass.first, 0, 0, 1);
    ProcessStatus::GetInstance().RecordEnd();
  }

  return true;
}

bool OptInlineAction(const ResourcePtr &resource) {
  if (parallel::ParallelContext::GetInstance()->parallel_mode() == "semi_auto_parallel" ||
      parallel::ParallelContext::GetInstance()->parallel_mode() == "auto_parallel") {
    return OptimizeAction(resource, kInlinePasses);
  }
  return true;
}

bool VmOptimizeAction(const ResourcePtr &resource) {
  EventMessage::PrintCompileStatusMessage("Start performing graph optimization.");
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_ps_mode()) {
    (void)kVmPasses.emplace_back(PassItem("server_communication_op_fusion", [](const ResourcePtr &res) -> bool {
      MS_EXCEPTION_IF_NULL(res);
      return ps::Util::FuseServerCommOps(res->func_graph());
    }));
  }
#endif
  auto ret = OptimizeAction(resource, kVmPasses);
  TraceManager::CloseParserDebugInfoFlag();
  return ret;
}

static bool IsCtrlSink() {
  auto ms_ctx = MsContext::GetInstance();
  if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {
    return false;
  }

  std::string device_target = ms_ctx->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (device_target != kAscendDevice) {
    return false;
  }

  if (!ms_ctx->get_param<bool>(MS_CTX_ENABLE_TASK_SINK)) {
    return false;
  }

  return ms_ctx->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK);
}

bool CheckGraphOutputConstOrParameter(const FuncGraphPtr &func_graph) {
  if (func_graph != nullptr) {
    AnfNodePtr output = func_graph->output();
    if (output != nullptr && (output->isa<ValueNode>() || output->isa<Parameter>())) {
      return true;
    }
  }
  return false;
}

bool GetJitBpropGraph(const ResourcePtr &resource) {
  // This function only works in PyNative mode; the func_graph is decorated with 'jit'.
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
    return true;
  }
  return pynative::PyNativeExecutor::GetInstance()->grad_executor()->jit()->GetJitGradGraph(resource);
}

bool RewriterAfterOptAPassAfterJitBprop(const ResourcePtr &resource) {
  // This function is only used to convert unsupported syntax into PyExecute nodes through Fallback,
  // when the forward graph is decorated with 'jit' and its derivative is taken in PyNative mode.
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->not_convert_jit()) {
    context->set_not_convert_jit(false);
    MS_EXCEPTION_IF_NULL(resource);
    FuncGraphPtr func_graph = resource->func_graph();
    MS_EXCEPTION_IF_NULL(func_graph);
    (void)mindspore::opt::RewriterAfterOptA(func_graph, resource);
    UpdateArgsSpec(func_graph, resource);
  }
  return true;
}

bool EliminateSpecialOpNode(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (resource->manager() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "PynativeElimOpt error, manager is null.";
  }
  if (resource->func_graph() == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "PynativeElimOpt error, graph is null.";
  }
  return EliminateSpecialOpOptPass(resource);
}

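// Detect indirect (incorporated) calls: a Partial/Switch/SwitchLayer whose callee is not a plain
// FuncGraph ValueNode means the call target is computed at run time.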
bool HasIncorporateCall(const std::vector<AnfNodePtr> &all_nodes) {
  for (const auto &node : all_nodes) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    auto cnode = node->cast<CNodePtr>();
    if (IsPrimitiveCNode(cnode, prim::kPrimPartial)) {
      auto partial_function = cnode->input(kPartialGraphIndex);
      if (!IsValueNode<FuncGraph>(partial_function)) {
        MS_LOG(INFO) << "Partial has indirect call: " << cnode->DebugString();
        return true;
      }
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimSwitch)) {
      const auto &switch_inputs = cnode->inputs();
      if (std::any_of(switch_inputs.begin() + kSwitchTrueBranchIndex, switch_inputs.end(),
                      [](const AnfNodePtr &input) {
                        return !IsPrimitiveCNode(input, prim::kPrimPartial) && !IsValueNode<FuncGraph>(input);
                      })) {
        MS_LOG(INFO) << "Switch has indirect call: " << cnode->DebugString();
        return true;
      }
      continue;
    }
    if (IsPrimitiveCNode(cnode, prim::kPrimSwitchLayer)) {
      auto make_tuple = cnode->input(kSwitchLayerBranchesIndex);
      if (!IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple)) {
        MS_LOG(INTERNAL_EXCEPTION) << "SwitchLayer input2 should be make_tuple, but got: " << make_tuple->DebugString();
      }
      const auto &make_tuple_inputs = make_tuple->cast<CNodePtr>()->inputs();
      if (std::any_of(make_tuple_inputs.begin() + 1, make_tuple_inputs.end(), [](const AnfNodePtr &input) {
            return !IsPrimitiveCNode(input, prim::kPrimPartial) && !IsValueNode<FuncGraph>(input);
          })) {
        MS_LOG(INFO) << "SwitchLayer has indirect call: " << cnode->DebugString();
        return true;
      }
      continue;
    }
    if (common::AnfAlgo::HasIncorporateCallNode(cnode)) {
      return true;
    }
  }
  return false;
}

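// Returns true when any CNode in the list is assigned to the given device target.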
bool ExistTarget(const std::vector<AnfNodePtr> &all_nodes, const std::string &target) {
  for (const auto &node : all_nodes) {
    if (node == nullptr || !node->isa<CNode>()) {
      continue;
    }
    if (GetCNodeTarget(node) == target) {
      return true;
    }
  }
  return false;
}

// If the return value of a subgraph is a Ref in control flow scenarios, graph mode should run kernel by kernel.
bool ExistSwitchRef(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &all_nodes) {
  // %1 = switch(cond, func1, func2)
  // %2 = %1() if the abstract of the node is AbstractRefTensor or Tuple/List(AbstractRefTensor, ...), return true.
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  auto &node_users = manager->node_users();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  for (const auto &node : all_nodes) {
    if (!IsPrimitiveCNode(node, prim::kPrimSwitch)) {
      continue;
    }
    auto iter = node_users.find(node);
    if (iter != node_users.end()) {
      auto &users = iter->second;
      for (auto &user : users) {
        auto &user_node = user.first;
        if (common::AnfAlgo::HasAbstractRef(user_node) || common::AnfAlgo::SequenceHasAbstractRef(user_node)) {
          if (device_target == kAscendDevice) {
            MS_LOG(WARNING) << "On the Ascend platform, if you only read the parameter, "
                            << "you can take the value of the parameter, so that the system can do more optimization. "
                            << "For example, change 'return param' to 'return param.value()'\n"
                            << "Please check your code: " << trace::GetDebugInfoStr(user_node->debug_info());
          }
          return true;
        }
      }
    }
  }
  return false;
}

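// Chooses the sink mode for graphs with control flow. Returns false once a mode has been decided here;
// returns true when the caller should continue with the default decision. The set_ctx arguments map as:
//   set_ctx(false, false, false) -> kernel by kernel
//   set_ctx(true,  false, false) -> subgraph sink
//   set_ctx(true,  true,  ...)   -> multi-graph sink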
bool SetModeForControlFlow(const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &all_nodes, bool pynative_mode,
                           compile::Backend *backend_ptr) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(backend_ptr);
  auto set_ctx = [&context_ptr, &backend_ptr](bool task_sink, bool is_multi_graph_sink, bool enable_loop_sink) {
    context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, task_sink);
    context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, is_multi_graph_sink);
    context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, enable_loop_sink);
    backend_ptr->set_is_multi_graph_sink(is_multi_graph_sink);
  };
  // GRAPH | Closure\ENV\While scenario: KernelByKernel path in MindRT.
  auto graphs = func_graph->func_graphs_used_total();
  (void)graphs.insert(func_graph);
  bool exist_control_flow = ExistControlFlow(func_graph);
  bool exist_func = exist_control_flow && HasIncorporateCall(all_nodes);
  if (exist_func) {
    if (!pynative_mode) {
      MS_LOG(INFO) << "Run graph mode with subgraph sink because graph has control flow and incorporate call.";
      set_ctx(true, false, false);
    } else {
      MS_LOG(INFO) << "Run graph mode with kernel by kernel because graph has control flow and incorporate call.";
      set_ctx(false, false, false);
    }
    return false;
  }
  bool exist_while =
    std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
  MS_LOG(INFO) << func_graph->ToString() << " exist_while: " << exist_while;
  if (exist_while || ExistSwitchRef(func_graph, all_nodes)) {
    if (!pynative_mode) {
      MS_LOG(INFO) << "Run graph mode with subgraph sink because graph has while or switch ref.";
      set_ctx(true, false, false);
    } else {
      MS_LOG(INFO) << "Run graph mode with kernel by kernel because graph has while or switch ref.";
      set_ctx(false, false, false);
    }
    return false;
  }
  // Multiple device targets scenario.
  if (func_graph->exist_multi_target()) {
    // Heterogeneous scenario + ControlFlow: KernelByKernel path in MindRT.
    if (exist_control_flow && pynative_mode) {
      MS_LOG(INFO) << "Run graph mode with kernel by kernel because graph has multi device target and control flow.";
      set_ctx(false, false, false);
      return false;
    }
    // GRAPH | Heterogeneous scenario: no control flow, subgraph sink path in MindRT.
    MS_LOG(INFO) << "Run graph mode with subgraph sink because graph has multi device target.";
    set_ctx(true, false, false);
    return false;
  }
  return true;
}

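// Returns true when the input is a FuncGraph value node carrying the cell reuse flag.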
bool IsCellReuse(const AnfNodePtr &input) {
  if (IsValueNode<FuncGraph>(input)) {
    auto fg = GetValueNode<FuncGraphPtr>(input);
    MS_EXCEPTION_IF_NULL(fg);
    if (fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
      return true;
    }
  }
  return false;
}

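// Downgrades the cell reuse level from lazy inline to no-inline when lazy inline cannot be applied:
// the graph contains while loops, a reused cell contains switch nodes or nested cell reuse, or the
// number of pipeline micro steps exceeds kLazyInlineThershold.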
void ProcessCanNotInline(const FuncGraphPtr &func_graph, const std::shared_ptr<MsContext> &context_ptr) {
  auto graphs = func_graph->func_graphs_used_total();
  (void)graphs.insert(func_graph);
  bool exist_while =
    std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
  if (exist_while && context_ptr->CellReuseLevel() == CellReuseLevel::kLazyInline) {
    MS_LOG(INFO) << "Set no inline because graph has while.";
    context_ptr->SetCellReuseLevel(CellReuseLevel::kNoInline);
  }

  auto cant_inline_cell_reuse = [](const FuncGraphPtr &fg) -> bool {
    if (!fg->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
      return false;
    }
    MS_LOG(INFO) << "Cell reuse graph: " << fg->ToString();
    // The cell reuse func graph has switch nodes.
    if (!fg->switch_nodes().empty()) {
      MS_LOG(INFO) << "Set no inline because cell reuse graph has switch, " << fg->ToString();
      return true;
    }
    // A sub graph of the reused cell has switch nodes or nested cell reuse.
    for (auto &sub_graph : fg->func_graphs_used_total()) {
      if (!sub_graph->switch_nodes().empty() || sub_graph->has_flag(FUNC_GRAPH_FLAG_CELL_REUSE)) {
        MS_LOG(INFO) << "Set no inline because cell reuse sub graph has switch or nested cell reuse, "
                     << sub_graph->ToString();
        return true;
      }
    }
    return false;
  };
  if (std::any_of(graphs.cbegin(), graphs.cend(), cant_inline_cell_reuse)) {
    MS_LOG(INFO) << "Set no inline because cell reuse graph has switch or nested cell reuse.";
    context_ptr->SetCellReuseLevel(CellReuseLevel::kNoInline);
  }
  if (!common::IsEnableRuntimeConfig(common::kRuntimeInline)) {
    const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
    size_t micro_num = 0;
    for (auto &node : all_nodes) {
      if (!node->isa<CNode>()) {
        continue;
      }
      auto cnode = node->cast<CNodePtr>();
      if (IsCellReuse(cnode->input(0))) {
        micro_num++;
      }
    }
    auto parallel_context = parallel::ParallelContext::GetInstance();
    MS_EXCEPTION_IF_NULL(parallel_context);
    auto stages = parallel_context->pipeline_stage_split_num();
    if (stages <= 1) {
      return;
    }
    MS_LOG(INFO) << "Cell reuse micro num: " << micro_num;
    if (micro_num > kLazyInlineThershold) {
      MS_LOG(INFO) << "Set no inline because cell reuse micro num is greater than " << kLazyInlineThershold
                   << ", micro num: " << micro_num;
      context_ptr->SetCellReuseLevel(CellReuseLevel::kNoInline);
    }
  }
}

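// Selects the run mode (task sink / multi-graph sink / loop sink) for graphs containing Ascend nodes;
// when falling back to kernel by kernel, the reason is recorded in *kbk_reason for later error reporting.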
void SetRunMode(const FuncGraphPtr &func_graph, compile::Backend *backend_ptr, std::string *kbk_reason) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(backend_ptr);
  auto set_ctx = [&context_ptr, &backend_ptr](bool task_sink, bool is_multi_graph_sink, bool enable_loop_sink) {
    context_ptr->set_param<bool>(MS_CTX_ENABLE_TASK_SINK, task_sink);
    context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, is_multi_graph_sink);
    context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, enable_loop_sink);
    backend_ptr->set_is_multi_graph_sink(is_multi_graph_sink);
  };
  ProcessCanNotInline(func_graph, context_ptr);
  auto jit_level = pipeline::GetJitLevel();
  func_graph->set_attr(kAttrJitLevel, MakeValue<std::string>(jit_level));
  auto jit_config = PhaseManager::GetInstance().jit_config();
  jit_config[kAttrJitLevel] = context_ptr->GetJitLevel();
  graphkernel::GraphKernelFlags::SaveJitConfig(jit_config);

  const bool pynative_mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
  const auto &device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (pynative_mode && device_target != kAscendDevice) {
    return;
  }
  const auto &all_nodes = TopoSort(func_graph->return_node(), SuccDeeperSimple, AlwaysInclude);
  // GPU/CPU does not need to set any context.
  if (!ExistTarget(all_nodes, kAscendDevice)) {
    return;
  }

  // GRAPH | Single Op: KernelByKernel path in MindRT.
  if (context_ptr->IsKByKExecutorMode()) {
    if (kbk_reason != nullptr) {
      *kbk_reason = "Run graph mode with kernel by kernel by configuration.";
      MS_LOG(INFO) << *kbk_reason;
    }
    set_ctx(false, false, false);
    return;
  }

  // GRAPH | Dynamic Shape: KernelByKernel path in MindRT.
  if (common::AnfAlgo::IsDynamicGraph(func_graph) && (context_ptr->backend_policy() != "ge")) {
    if (kbk_reason != nullptr) {
      *kbk_reason =
        "Run graph mode with kernel by kernel because graph has dynamic shapes. Call "
        "'set_context(save_graphs=True)' to check the graph IRs.";
      MS_LOG(INFO) << *kbk_reason;
    }
    set_ctx(false, false, false);
    return;
  }

  // GRAPH | Dynamic Scalar: dynamic scalar ops in graph.
  if (IsNeedBackoffGraph(func_graph) && !common::AnfAlgo::IsDynamicGraph(func_graph)) {
    if (kbk_reason != nullptr) {
      *kbk_reason = "Run graph mode with kernel by kernel because graph has dynamic scalar ops.";
      MS_LOG(INFO) << *kbk_reason;
    }
    set_ctx(false, false, false);
    return;
  }
  if (!SetModeForControlFlow(func_graph, all_nodes, pynative_mode, backend_ptr)) {
    return;
  }

#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->cache_enable()) {
    MS_LOG(INFO) << "Run graph mode with subgraph sink because PS cache enable.";
    set_ctx(true, false, false);
    return;
  }
#endif

  // GRAPH | normal network and if/for/switch scenario etc.: MultiGraph path in MindRT.
  MS_LOG(INFO) << "Run graph mode with multi graph sink.";
  set_ctx(true, true, !pynative_mode);
}

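// Legacy run mode selection, used when MindRT is disabled.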
void OriginSetRunMode(const ResourcePtr &resource) {
  FuncGraphPtr func_graph = resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  auto bc_ptr = resource->GetBackend();
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string backend = context_ptr->backend_policy();
  auto task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  if (func_graph->exist_multi_target() || !task_sink) {
    bc_ptr->set_is_multi_graph_sink(false);
    context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
    context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
  } else if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
    std::string device_target = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
    auto manager = func_graph->manager();
    auto graphs = manager->func_graphs();
    if (graphs.size() > 1 && device_target == kAscendDevice) {
      MS_LOG(INFO) << "This func_graph has control flow nodes, owns " << graphs.size() << " subgraphs.";
    }
    bool exist_while =
      std::any_of(graphs.cbegin(), graphs.cend(), [](const FuncGraphPtr &fg) { return fg->recursive(); });
    if (device_target == kAscendDevice && backend != kMsVm && !exist_while) {
      MS_LOG(INFO) << "Run graph mode with multi graph sink.";
      bc_ptr->set_is_multi_graph_sink(true);
      context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, true);
    } else {
      MS_LOG(INFO) << "Run graph mode with vm.";
      bc_ptr->set_is_multi_graph_sink(false);
      context_ptr->set_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK, false);
      context_ptr->set_param<bool>(MS_CTX_ENABLE_LOOP_SINK, false);
    }
  }
}

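// Entry point for run mode selection: dispatches to the MindRT-based SetRunMode or the legacy
// OriginSetRunMode, then rejects kernel-by-kernel execution for jobs launched through 'ranktable'.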
void SetRunMode(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // The root cause of KernelByKernel mode should be returned.
  std::string kbk_reason = "";
  if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    SetRunMode(resource->func_graph(), resource->GetBackend().get(), &kbk_reason);
  } else {
    OriginSetRunMode(resource);
  }
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  auto is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
  auto enable_hccl = context_ptr->get_param<bool>(MS_CTX_ENABLE_HCCL);
  if ((!is_task_sink ||
       (context_ptr->IsKByKExecutorMode() && common::AnfAlgo::IsDynamicGraph(resource->func_graph()))) &&
      mode == kGraphMode && enable_hccl && !common::UseHostCollective() && common::GetEnv(kSimulationLevel).empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "The current execution mode is 'kernelbykernel', reason: " << kbk_reason
                               << ", but you're launching the job using 'ranktable', which "
                                  "does not support 'kernelbykernel' mode.\nPlease refer to the link: "
                                  "https://www.mindspore.cn/tutorials/experts/en/master/parallel/startup_method.html "
                                  "and use 'Dynamic cluster' (suggested) or 'mpirun' to launch your job.";
  }
}

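// Compiles the optimized func_graph into an executable form: MindRT actor graphs, a control-sink graph
// on Ascend, or FinalVM instructions, depending on the backend policy and context.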
bool TaskEmitAction(const ResourcePtr &resource) {
  EventMessage::PrintCompileStatusMessage("Start generating kernels.");
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "TaskEmit args error";
  }
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto mode = context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE);
  if (mode == kGraphMode && CheckGraphOutputConstOrParameter(func_graph)) {
    return true;
  }

  // In PyNative mode, a multi-target graph generates -1 (dynamic) shapes in jit, and jit with -1 shapes
  // runs as a call graph; control flow graphs do not carry the kFlagJitCallGraph flag.
  bool is_control_flow = !func_graph->func_graphs_used_total().empty();
  if (mode == kGraphMode || (mode == kPynativeMode && (func_graph->has_flag(kFlagJitCallGraph) || is_control_flow))) {
    func_graph->SetMultiTarget();
    if (func_graph->exist_multi_target() && DumpJsonParser::GetInstance().IsDumpEnabled()) {
      MS_LOG(WARNING) << "Multi device target is detected, CPU data is dumped in rank_0 directory";
    }
  }
  DisableMindRT(resource);

  SetRunMode(resource);
  auto bc_ptr = resource->GetBackend();
  MS_EXCEPTION_IF_NULL(bc_ptr);
  const auto &backend = context_ptr->backend_policy();
  // The graph compiling of MindRT.
  if ((backend == kMsConvert || backend == kGeVm) && context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    TaskEmitActionForMindRT(resource);
    return true;
  }
  // The graph compiling of control sink.
  if (IsCtrlSink() && (backend == kMsConvert || backend == kGeVm)) {
    auto graph_id = bc_ptr->CompileGraph(NOT_NULL(func_graph));
    resource->SetResult(kOutput, graph_id);
    return true;
  }
  std::vector<PrimitivePtr> cut_list = compile::GetNonlinearOps();
  if (bc_ptr->name() == kMsConvert || bc_ptr->name() == kGeVm) {
    cut_list = compile::GetMsNonlinearOps();
  }
  std::shared_ptr<CompileGraphs> compile = std::make_shared<CompileGraphs>(bc_ptr, cut_list);
  auto vm = compile->CompileAndLink(func_graph);
  resource->SetResult(kOutput, vm);
  return true;
}

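// Wraps the compiled result into a runnable VmEvalFunc (or delegates to MindRT/GE) and stores it back
// into the resource under kOutput.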
bool ExecuteAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode &&
      CheckGraphOutputConstOrParameter(resource->func_graph())) {
    return true;
  }
  if (!resource->HasResult(kOutput)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Execute args error";
  }
  std::string backend = MsContext::GetInstance()->backend_policy();
  // The graph running of MindRT.
  if ((backend == kMsConvert || backend == kGeVm) && MsContext::GetInstance()->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
    ExecuteActionForMindRT(resource);
    return true;
  }

  // The graph running of control sink.
  if (IsCtrlSink() && (backend == kMsConvert || backend == kGeVm)) {
    auto graph_id = resource->GetResult(kOutput).cast<GraphId>();
    auto bc_ptr = resource->GetBackend();
    compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
    MS_EXCEPTION_IF_NULL(msbc_ptr);
    compile::VmEvalFuncPtr run =
      std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
        MS_LOG(INFO) << "Execute args size " << args.size();
        auto outs = msbc_ptr->RunGraph(graph_id, args);
        MS_LOG(DEBUG) << "out size " << outs.size();
        return outs[0];
      });
    resource->SetResult(kOutput, run);
    return true;
  }

  compile::FinalVMPtr vm = resource->GetResult(kOutput).cast<compile::FinalVMPtr>();
  if (vm == nullptr) {
    MS_LOG(INFO) << "Call GE to run the func_graph instead of VM";
    return true;
  }
  compile::VmEvalFuncPtr run =
    std::make_shared<compile::VmEvalFunc>(std::bind(&compile::FinalVM::Eval, vm, std::placeholders::_1));
  resource->SetResult(kOutput, run);
  return true;
}

#if defined(__linux__) && defined(WITH_BACKEND)
bool DistributedSplitAction(const ResourcePtr &resource) {
  // Only run this action when the cluster is initialized.
  if (!distributed::cluster::ClusterContext::instance()->initialized()) {
    return true;
  }
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
  auto node = distributed::cluster::ClusterContext::instance()->node();
  MS_EXCEPTION_IF_NULL(node);
  auto node_role = distributed::cluster::ClusterContext::instance()->node_role();

  parallel::GraphSplitterPtr splitter =
    std::make_shared<parallel::GraphSplitter>(func_graph, node->rank_id(), node_role);
  MS_EXCEPTION_IF_NULL(splitter);
  splitter->Run();
  // Renormalize: infer shapes and set abstracts for all nodes in the graph.
  if (func_graph->has_flag(kFlagNeedRenormalize)) {
    abstract::AbstractBasePtrList args_abs;
    auto parameters = func_graph->parameters();
    (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_abs),
                         [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
    FuncGraphPtr new_fg = Renormalize(resource, func_graph, args_abs);
    resource->set_func_graph(new_fg);
    resource->set_args_abs(args_abs);
  }
  return true;
}
#endif

// The value node related to a parallel primitive might be partitioned so that its value changes by device,
// which would cause a synchronization error due to a different execution order.
// Here we temporarily avoid the problem by skipping value node merging for parallel-related primitives;
// the final solution will be proposed later as a parallel feature.
bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  MS_EXCEPTION_IF_NULL(resource->manager());
  auto &node_users = resource->manager()->node_users();
  auto &users = node_users[value_node];
  auto used_by_keep_value_prim =
    std::any_of(users.begin(), users.end(), [](const std::pair<AnfNodePtr, int64_t> &user) -> bool {
      MS_EXCEPTION_IF_NULL(user.first);
      auto cnode = user.first->cast<CNodePtr>();
      if (cnode == nullptr) {
        return false;
      }
      auto prim_node = cnode->input(0);
      if (IsValueNode<Primitive>(prim_node)) {
        auto prim = GetValue<PrimitivePtr>(prim_node->cast<ValueNodePtr>()->value());
        MS_EXCEPTION_IF_NULL(prim);
        // The value_node is referenced by some parallel primitive.
        return prim->HasAttr("keep_value_node_input");
      }
      return false;
    });
  return used_by_keep_value_prim;
}

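// Merges value nodes that hold identical values, skipping the ones parallel primitives require to stay
// duplicated (see KeepValueNodeDuplication above).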
bool RemoveValueNodeDuplicationsAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  FuncGraphPtr func_graph = resource->func_graph();
  if (func_graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Remove value node duplications error.";
  }
  auto manager = resource->manager();
  // Remove duplicated value nodes. Because of the replace operation, we can't iterate by reference.
  auto value_nodes = func_graph->value_nodes();
  HashCache hash_cache;
  HashValue hashes;
  for (const auto &value_pair : value_nodes) {
    if (KeepValueNodeDuplication(value_pair.first, resource)) {
      continue;
    }
    TryToDoReplace(manager.get(), value_pair.first, &hash_cache, &hashes);
  }
  return true;
}

bool PipelineSplitAction(const ResourcePtr &resource) { return PipelineSplitPass(resource); }

bool ParallelVirtualDatasetAction(const ResourcePtr &resource) { return ParallelVirtualDatasetPass(resource); }

bool AutoParallelSymbolWithReNormalizeAction(const ResourcePtr &resource) {
  return AutoParallelSymbolPassWithReNormalize(resource);
}

bool PipelineSchedulerAction(const ResourcePtr &resource) { return PipelineParallelScheduler(resource); }

bool AutoParallelAction(const ResourcePtr &resource) { return AutoParallelPass(resource); }

bool ValidateAction(const ResourcePtr &resource) {
  auto res = ValidatePass(resource);
#ifdef DEBUG
  FuncGraphLoopBreaker::Inst().Dump();
#endif
  return res;
}

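// Restores the func_graph previously exported to MindIR from the cell attribute 'graph_load_from_mindir'
// and re-runs MindIR inference when the current input arguments do not match the exported ones.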
bool SetMindIRGraphAction(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  resource->set_is_load(true);
  auto cell = py::cast<CellPtr>(resource->source_input());
  if (cell == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "The graph loaded from mindir is null.";
  }
  const std::string mindir_graph = "graph_load_from_mindir";
  auto obj = cell->GetAttr(mindir_graph);
  if (obj == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "The graph loaded from mindir is null. The cell has no attribute: " << mindir_graph;
  }
  auto fg = GetValue<FuncGraphPtr>(obj);
  if (fg == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "The graph loaded from mindir is null.";
  }
  resource->set_func_graph(fg);
  FuncGraphManagerPtr mng = fg->manager();
  if (mng == nullptr) {
    auto res_mng = resource->manager();
    MS_EXCEPTION_IF_NULL(res_mng);
    res_mng->Clear();
    res_mng->AddFuncGraph(fg);
  }
  abstract::AbstractBasePtrList broaded_args;
  const auto &args_abs_list = resource->args_abs();
  (void)std::transform(args_abs_list.begin(), args_abs_list.end(), std::back_inserter(broaded_args),
                       [](const AbstractBasePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         if (arg->GetValueTrack() != kValueAny) {
                           return arg->Broaden();
                         }
                         return arg;
                       });

  abstract::AbstractBasePtrList func_args;
  const auto inputs = fg->get_inputs();
  (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(func_args),
                       [](const AnfNodePtr &arg) -> AbstractBasePtr {
                         MS_EXCEPTION_IF_NULL(arg);
                         auto abs = arg->abstract();
                         MS_EXCEPTION_IF_NULL(abs);
                         return abs->Broaden();
                       });

  bool is_equal_input_args = true;
  if (!AbstractBasePtrListDeepEqual(func_args, broaded_args)) {
    MS_LOG(INFO) << "The input arguments are not compatible with the function graph that was exported before. "
                 << "Please check that the args are the same as at export time.\n"
                 << "The export input argument size: " << func_args.size() << "\n"
                 << "The load input argument size: " << broaded_args.size() << "\n"
                 << "Export input args info: " << abstract::ArgsToString(func_args) << "\n"
                 << "The input args info: " << abstract::ArgsToString(broaded_args);
    is_equal_input_args = false;
  }

  if (!is_equal_input_args) {
    // Use InferMindir, which will find the C++ infer in eval_map and backend_eval_map.
    (void)InferMindir(resource->func_graph(), args_abs_list, true);
  }
  return true;
}

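// The common front end of the compile pipeline: parse/bootstrap and resolve, type inference, auto monad,
// inline, and the parallel-related actions shared by all compile phases.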
static std::vector<ActionItem> CommonPipeline(bool trace_flag) {
  std::vector<ActionItem> actions;
  auto graph_executor = pipeline::GraphExecutorPy::GetInstance();
  MS_EXCEPTION_IF_NULL(graph_executor);
  const bool boost_infer = common::GetEnv("MS_DEV_BOOST_INFER") != "0" && graph_executor->graph_cell_count() == 0;
  if (!trace_flag) {
    if (boost_infer) {
      // Bootstrap for JIT.
      (void)actions.emplace_back(std::make_pair(kBootstrap, BootstrapAction));
    } else {
      // Parse the Python AST into an ANF graph.
      (void)actions.emplace_back(std::make_pair(kParse, ParseAction));

      // Resolve the Python functions.
      (void)actions.emplace_back(std::make_pair(kSymbolResolve, SymbolResolveAction));

      // Notice: temporary solution, to be implemented using the Python rewriter in the future.
      // Set the mixed precision flag in subgraphs.
      static bool enable_set_mixed_precision_flag = (common::GetCompileConfig("AMP_ENABLE_ALL_FG") == "1");
      if (enable_set_mixed_precision_flag) {
        (void)actions.emplace_back(std::make_pair(kSetMixedPrecisionFlag, SetMixedPrecisionAction));
      }

      auto parallel_context = parallel::ParallelContext::GetInstance();
      MS_EXCEPTION_IF_NULL(parallel_context);
      auto parallel_mode = parallel_context->parallel_mode();
      const bool is_parallel_mode =
        parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
      static const auto combine_like_graphs = (common::GetCompileConfig("COMBINE_LIKE_GRAPHS") == "1");
      static const auto force_disable_combine = (common::GetCompileConfig("COMBINE_LIKE_GRAPHS") == "0");
      if (!is_cluster_initialized && (!is_parallel_mode || combine_like_graphs) && !force_disable_combine) {
        (void)actions.emplace_back(std::make_pair(kCombineLikeGraphs, CombineLikeGraphs));
      }

      // Make the reusable cell into a reusable function graph.
      (void)actions.emplace_back(std::make_pair(kGraphReusing, GraphReusingAction));

      // Pre-lift the func graphs.
      (void)actions.emplace_back(std::make_pair(kPreCConv, PreCConvAction));
    }
  }
  // Evaluate type and shape, and specialize.
  (void)actions.emplace_back(std::make_pair(kTypeInference, TypeInferenceAction));

  // Auto monad for side-effect handling.
  (void)actions.emplace_back(std::make_pair(kAutoMonad, AutoMonadAction));

  if (boost_infer) {
    (void)actions.emplace_back(std::make_pair(kGraphReusing, GraphReusingAction));
  }

  // Do data structure simplifications and inline.
  (void)actions.emplace_back(std::make_pair(kInline, OptInlineAction));

  (void)actions.emplace_back(std::make_pair("parallel-infer-symbol", AutoParallelSymbolWithReNormalizeAction));
  // Do prepositive auto parallel.
  (void)actions.emplace_back(std::make_pair(kPreAutoParallel, AutoParallelAction));
  // Insert the virtual dataset.
  (void)actions.emplace_back(std::make_pair("insert-virtual-dataset", ParallelVirtualDatasetAction));
  (void)actions.emplace_back(std::make_pair("parallel-infer-symbol-second", AutoParallelSymbolWithReNormalizeAction));
  // Do the PipelineSplit action.
  (void)actions.emplace_back(std::make_pair(kPipelineSplit, PipelineSplitAction));

  return actions;
}

std::vector<ActionItem> EraseParseActions(const std::vector<ActionItem> &actions) {
  std::vector<ActionItem> filtered_actions;
  for (const auto &item : actions) {
    if (item.first != "parse") {
      (void)filtered_actions.emplace_back(item);
    }
  }
  return filtered_actions;
}

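// Assembles the full VM pipeline: the common front end, the optimize/validate middle end and the backend
// TaskEmit/Execute actions. A minimal sketch of how a caller might consume the returned ActionItem list
// (the driver loop here is illustrative, not the actual pipeline runner):
//   auto actions = VmPipeline(resource, /*trace_flag=*/false, /*erase_parse=*/false);
//   for (const auto &action : actions) {
//     if (!action.second(resource)) {  // each action returns false to abort compilation
//       break;
//     }
//   }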
std::vector<ActionItem> VmPipeline(const ResourcePtr &resource, bool trace_flag, bool erase_parse) {
  is_cluster_initialized = distributed::cluster::ClusterContext::instance()->initialized();
  std::vector<ActionItem> actions;
  // If compilation cache is enabled and the cache was read successfully, only do the backend actions.
  const std::string &phase = PhaseManager::GetInstance().phase();
  if (IsPhaseLoadFromMindIR(phase)) {
    actions = MindIRPipeline();
  } else if (!resource->EnableCompileCache() || resource->func_graph() == nullptr) {
    actions = CommonPipeline(trace_flag);

    // Optimize.
    (void)actions.emplace_back(std::make_pair(kOptimize, VmOptimizeAction));

    (void)actions.emplace_back(std::make_pair(kPipelineParallelScheduler, PipelineSchedulerAction));

    (void)actions.emplace_back(std::make_pair(kAutoMonadReorder, OrderEnforceAction));

    // Eliminate forward cnodes for the grad graph.
    (void)actions.emplace_back(std::make_pair(kGetJitBpropGraph, GetJitBpropGraph));

    // Rewriter (dict converted to PyExecute) after jit bprop.
    (void)actions.emplace_back(std::make_pair(kRewriterAfterJitBprop, RewriterAfterOptAPassAfterJitBprop));

    // Eliminate the virtual mirror node.
    (void)actions.emplace_back(std::make_pair(kEliminateSpecialOpNode, EliminateSpecialOpNode));

#if defined(__linux__) && defined(WITH_BACKEND)
    if (!pipeline::IsPhaseExport(phase)) {
      (void)actions.emplace_back(std::make_pair(kDistributedSplit, DistributedSplitAction));
    }
    if (ps::PSContext::instance()->is_worker()) {
      if (distributed::cluster::ClusterContext::instance()->initialized()) {
        MS_LOG(INFO) << "This worker is initialized. No need to add worker action.";
      } else {
        std::string server_mode = ps::PSContext::instance()->server_mode();
      }
    }
#endif

    // Mind Compiler finishes.
    (void)actions.emplace_back(std::make_pair(kValidate, ValidateAction));
  }

  if (erase_parse) {
    actions = EraseParseActions(actions);
  }

  auto is_precompile_only = MsContext::GetInstance()->get_param<bool>(MS_CTX_PRECOMPILE_ONLY);
  if (is_precompile_only) {
    MS_LOG(INFO) << "PrecompileOnly, stop running the graph";
    return actions;
  }

  if (common::GetEnv(kSimulationLevel) == kSimulationLevelCompileGraph) {
    return actions;
  }

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
#ifndef WITH_BACKEND
  if (ms_context->backend_policy() != "ge") {
#endif
    // Phases with the "export" prefix need to skip backend compilation.
    if (pipeline::IsPhaseExport(phase)) {
      return actions;
    }
    // Compile the ANF graph.
    (void)actions.emplace_back(std::make_pair(kTaskEmit, TaskEmitAction));

    // Execute the graph.
    (void)actions.emplace_back(std::make_pair(kExecute, ExecuteAction));
#ifndef WITH_BACKEND
  }
#endif
  return actions;
}

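// The reduced pipeline for graphs loaded from MindIR: load, modify, infer and validate only.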
std::vector<ActionItem> MindIRPipeline() {
  auto context_ptr = MsContext::GetInstance();
  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
    MS_LOG(EXCEPTION)
      << "The graph generated from MindIR does not support execution in PyNative mode. Please convert "
         "it to graph mode.";
  }
  std::vector<ActionItem> actions;
  // Set the funcGraph loaded from MindIR to the resource.
  (void)actions.emplace_back(std::make_pair(kLoadMindir, SetMindIRGraphAction));
  (void)actions.emplace_back(std::make_pair(kModifyMindirGraph, ModifyGraphGeneratedByMindIR));
  (void)actions.emplace_back(std::make_pair(kInferMindir, InferMindIR));
  (void)actions.emplace_back(std::make_pair(kValidate, ValidateAction));
  return actions;
}
}  // namespace pipeline
}  // namespace mindspore