/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/ps/pipeline.h"

#include <memory>
#include <map>
#include <cstdlib>
#include <algorithm>
#include <iomanip>
#include <unordered_map>
#include <functional>
#include "mindspore/core/ops/framework_ops.h"
#include "pybind_api/pybind_patch.h"
#include "pybind11/pybind11.h"
#include "ir/param_info.h"
#include "pipeline/jit/ps/action.h"
#include "pipeline/jit/ps/pass.h"
#include "pipeline/jit/ps/parse/data_converter.h"
#include "pipeline/jit/ps/static_analysis/async_eval_result.h"
#include "pipeline/jit/ps/compile_cache_manager.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "frontend/parallel/graph_util/flops_collection.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/parallel/pass/handle_group_info.h"
#include "frontend/parallel/step_assigned_parallel.h"
#include "frontend/parallel/dynamic_shape/dynamic_shape.h"
#include "frontend/expander/utils.h"
#include "include/common/utils/config_manager.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/python_utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "utils/crypto.h"
#include "utils/phase.h"
#include "utils/compile_config.h"
#include "include/common/utils/comm_manager.h"
#include "include/common/utils/stub_tensor.h"
#include "utils/interpret_node_recorder.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
#include "pipeline/jit/ps/fallback.h"
#include "pipeline/jit/ps/debug/trace.h"
#include "pipeline/jit/ps/event_message_print.h"
#include "include/common/debug/draw.h"
#include "include/common/debug/common.h"
#include "load_mindir/load_model.h"
#include "backend/graph_compiler/segment_runner.h"
#include "backend/common/session/executor_manager.h"
#include "backend/common/session/session_factory.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/device/stream_synchronizer.h"
#include "include/common/fallback.h"
#include "include/common/profiler.h"
#include "include/backend/distributed/collective/collective_manager.h"
#include "include/backend/distributed/recovery/recovery_context.h"
#include "include/common/utils/dynamic_obfuscation/dynamic_obfuscation.h"
#include "include/common/utils/dynamic_obfuscation/registry_opaque_predicate.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "include/backend/distributed/init.h"
#include "include/backend/debug/profiler/profiling.h"
#include "kernel/graph_kernel/graph_kernel_builder_manager.h"
#include "kernel/graph_kernel_info.h"
#include "include/backend/data_queue/data_queue_mgr.h"
#include "mindspore/core/symbolic_shape/symbol_info.h"
#include "include/common/symbol_engine/symbol_engine_impl.h"
#include "pipeline/jit/ps/load_mindir.h"
#include "load_mindir/infer_mindir.h"

#ifndef ENABLE_SECURITY
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/debug/data_dump/acl_dump_json_writer.h"
#include "abstract/abstract_value.h"
#endif
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/constants.h"
#include "include/backend/distributed/ps/util.h"
#include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
#include "include/backend/distributed/cluster/cluster_context.h"
#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "include/backend/distributed/embedding_cache/data_queue_manager.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/graph_recorder.h"
#include "include/common/debug/rdr/recorder_manager.h"
#include "ir/cell.h"
#endif

#include "pybind_api/ir/log_adapter_py.h"  // Only include once in the whole project.
#include "pybind_api/ir/py_execute_py.h"   // Only include once in the whole project.
#include "include/common/utils/compile_cache_context.h"

namespace mindspore {
// namespace to support intermediate representation definition
namespace pipeline {
using Tensor = mindspore::tensor::Tensor;
using MetaTensor = mindspore::tensor::MetaTensor;
using MetaSparseTensor = mindspore::tensor::MetaSparseTensor;
using CSRTensor = mindspore::tensor::CSRTensor;
using COOTensor = mindspore::tensor::COOTensor;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
using DeviceTensor = mindspore::device::DeviceAddress;

const char IR_TYPE_ANF[] = "anf_ir";
const char IR_TYPE_ONNX[] = "onnx_ir";
const char IR_TYPE_MINDIR[] = "mind_ir";

GraphExecutorPyPtr GraphExecutorPy::executor_ = nullptr;
std::mutex GraphExecutorPy::instance_lock_;

std::unordered_map<abstract::AbstractBasePtrList, uint64_t, abstract::AbstractBasePtrListHasher,
                   abstract::AbstractBasePtrListEqual>
  kArgsCache;
std::unordered_map<PyObject *, abstract::AbstractBasePtrList> kCellArgsMap;

namespace {
#ifdef ENABLE_DUMP_IR
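// Builds the zero-padded "<stage_idx>_<action_name>" base name used for per-stage IR dump files.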
std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
  std::ostringstream oss;
  int spaces = 2;
  oss << std::setfill('0') << std::setw(spaces) << stage_idx << "_" << action_name;
  return oss.str();
}
#endif

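// Recursively checks that every element of the tuple is a tensor or a (nested) tuple of tensors.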
bool CheckAllTensor(const ValueTuplePtr &value_tuple) {
  auto elements = value_tuple->value();
  for (auto element : elements) {
    MS_EXCEPTION_IF_NULL(element);
    if (!(element->isa<ValueTuple>() && CheckAllTensor(element->cast<ValueTuplePtr>())) &&
        !(element->isa<MetaTensor>())) {
      return false;
    }
  }
  return true;
}

bool Mutable(const py::object &obj, const ValuePtr &value) {
  // If a tensor has been set as a const arg, it should not be mutable.
  if (value->isa<MetaTensor>()) {
    constexpr char const_arg_attr[] = "const_arg";
    if (py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr))) {
      return false;
    }
  }
  constexpr char mutable_attr[] = "__ms_mutable__";
  return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
}

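// If the object carries the __ms_dynamic_len__ flag, converts its sequence abstract to a dynamic-length sequence.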
bool CheckAndConvertToVariableLenSequence(const py::object &obj, AbstractBasePtr abs) {
  constexpr char variable_len_attr[] = "__ms_dynamic_len__";
  bool dynamic_len = (py::hasattr(obj, variable_len_attr) && py::cast<bool>(py::getattr(obj, variable_len_attr)));
  if (!dynamic_len) {
    return false;
  }
  if (!abs->isa<abstract::AbstractSequence>()) {
    MS_EXCEPTION(TypeError) << "For mutable, when the dynamic_len is True, the first input should be"
                            << " list or tuple, but got: " << abs->ToString();
  }
  auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
  abs_seq->CheckAndConvertToDynamicLenSequence();
  return true;
}

bool TensorArgMutable(const py::object &obj, const ValuePtr &value) {
  if (!value->isa<MetaTensor>()) {
    return false;
  }
  constexpr char const_arg_attr[] = "const_arg";
  return !py::hasattr(obj, const_arg_attr) || !py::cast<bool>(py::getattr(obj, const_arg_attr));
}

bool EnableTupleBroaden(const ValuePtr &value, bool enable_tuple_broaden) {
  return enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>());
}

bool GradForScalar(const ValuePtr &value) {
  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>();
}

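// Converts a Python argument to its abstract value. The abstract is broadened when the argument is a mutable
// (non-const) tensor, explicitly marked mutable, a sparse tensor, a broadened tuple, or a scalar with grad enabled.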
AbstractBasePtr ArgsToAbstract(const py::object &arg, const ValuePtr &value, bool enable_tuple_broaden = false) {
  bool broaden = TensorArgMutable(arg, value) || Mutable(arg, value) || value->isa<MetaSparseTensor>() ||
                 EnableTupleBroaden(value, enable_tuple_broaden) || GradForScalar(value);
  auto ret = abstract::ToAbstract(value, nullptr, nullptr);
  if (broaden) {
    ret = AbstractBroaden(ret);
  }
  auto is_dynamic_len = CheckAndConvertToVariableLenSequence(arg, ret);
  if (fallback::EnableFallbackListDictInplace() && !broaden && !is_dynamic_len) {
    // Attach corresponding list python object for constant list input.
    fallback::AttachPyObjToAbs(ret, arg, false);
  }
  return ret;
}

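// Returns true if the argument type is supported as a network input, recursing into lists, tuples and dicts.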
bool CheckArgValid(const py::handle &arg) {
  if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) {
    auto vector_arg = py::cast<py::list>(arg);
    return std::all_of(vector_arg.begin(), vector_arg.end(), CheckArgValid);
  }

  if (py::isinstance<py::dict>(arg)) {
    auto dict_arg = py::cast<py::dict>(arg);
    return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
  }

  if (py::isinstance<Tensor>(arg) || IsStubTensor(arg)) {
    auto tensor = IsStubTensor(arg) ? ConvertStubTensor(arg) : py::cast<TensorPtr>(arg);
    if (tensor->data_type() == kNumberTypeBool) {
      MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause "
                   << "operator compilation failure. For more details, please refer to the FAQ at "
                   << "https://mindspore.cn/search?[AddN]%20input(kNumberTypeBool.";
    }
  }

  return IsStubTensor(arg) || py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) ||
         py::isinstance<py::none>(arg) || py::isinstance<Number>(arg) || py::isinstance<py::str>(arg) ||
         py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg) || py::isinstance<COOTensor>(arg);
}

std::string GetCompileExceptionInfo() {
  std::ostringstream oss;
  trace::GetTraceStackInfo(oss);
  return oss.str();
}

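// Configures the VM loop flag and loop size on the resource, depending on the device target and whether MindRT is
// enabled.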
void SetLoopCount(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto func_graph = resource->func_graph();
  if (func_graph != nullptr && func_graph->manager() != nullptr) {
    auto manager = func_graph->manager();
    size_t graph_nums = manager->func_graphs().size();
    int64_t loop_size = ConfigManager::GetInstance().iter_num();
    const auto context_ptr = MsContext::GetInstance();
    bool enable_mind_rt = context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT);
    if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
      resource->set_vm_loop(!(context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK) || enable_mind_rt), loop_size);
    } else if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice) {
      bool run_with_mind_rt = graph_nums == 1 || enable_mind_rt;
      resource->set_vm_loop(!run_with_mind_rt, loop_size);
    }
    MS_LOG(INFO) << "Change vm_loop_flag to " << resource->vm_loop_flag() << ", set loop_size to " << loop_size;
  }
}

std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) {
  std::map<string, string> ret{};
  for (auto jit_param = jit_config.begin(); jit_param != jit_config.end(); ++jit_param) {
    auto param_name = py::cast<std::string>(jit_param->first);
    auto param_value = py::cast<std::string>(jit_param->second);
    ret[param_name] = param_value;
  }
  return ret;
}

void RecordInitStatus() {
  static bool printed = false;
  if (!printed) {
    MS_LOG(INFO) << "Status record: system init.";
    printed = true;
  }
}

void RecordExitStatus() { MS_LOG(INFO) << "Status record: system exit."; }

std::string ToOrdinal(const size_t &i) {
  auto suffix = "th";
  if (i == kIndex1) {
    suffix = "st";
  } else if (i == kIndex2) {
    suffix = "nd";
  } else if (i == kIndex3) {
    suffix = "rd";
  }
  return std::to_string(i) + suffix;
}

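// For the fallback runtime, fetches the PyExecute user data attached to the output tensor's device address, if any.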
kernel::PyExecuteOutputUserDataPtr GetUserDataFromAddress(const py::object &res) {
  const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
  if (!allow_fallback_runtime) {
    return nullptr;
  }

  if (py::isinstance<tensor::Tensor>(res) || IsStubTensor(res)) {
    auto res_tensor = IsStubTensor(res) ? ConvertStubTensor(res) : res.cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(res_tensor);
    if (res_tensor->device_address() != nullptr) {
      auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(res_tensor->device_address());
      MS_LOG(DEBUG) << "res tensor_address:" << tensor_address;
      MS_EXCEPTION_IF_NULL(tensor_address);
      if (tensor_address->user_data() != nullptr) {
        return tensor_address->user_data()->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key);
      }
    }
  }
  return nullptr;
}

py::object BaseRefToPyDataWithUserData(const BaseRef &value, const AbstractBasePtr &abs);

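// Converts a VectorRef to a py::tuple or py::list (the template parameter T) element by element, skipping None
// abstracts when the fallback runtime is disabled.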
template <typename T>
py::object GetVectorRefPyDataWithAbstract(const VectorRef &value_list, const abstract::AbstractSequencePtr &seq_abs) {
  auto value_size = value_list.size();
  auto ret = T(value_size);

  const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
  size_t ref_idx = 0;
  for (size_t i = 0; i < seq_abs->size(); ++i) {
    auto elem_abs = seq_abs->elements()[i];
    if (elem_abs->isa<abstract::AbstractNone>() && !allow_fallback_runtime) {
      continue;
    }
    ret[ref_idx] = BaseRefToPyDataWithUserData(value_list[ref_idx], elem_abs);
    ref_idx++;
  }
  if (ref_idx != value_size) {
    MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
                      << ref_idx;
  }
  return ret;
}

py::object GetVectorRefPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
  if (abs == nullptr || abs->isa<abstract::AbstractCSRTensor>() || abs->isa<abstract::AbstractCOOTensor>() ||
      abs->isa<abstract::AbstractAny>()) {
    return BaseRefToPyData(value_list, abs);
  }
  // Need to consider AbstractAny with vector ref scene later.
  if (!abs->isa<abstract::AbstractSequence>()) {
    MS_LOG(EXCEPTION) << "Can not convert vector ref with abstract " << abs->ToString();
  }
  auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
  if (seq_abs->dynamic_len()) {
    return BaseRefToPyData(value_list, abs);
  }
  if (seq_abs->isa<abstract::AbstractTuple>()) {
    return GetVectorRefPyDataWithAbstract<py::tuple>(value_list, seq_abs);
  }
  return GetVectorRefPyDataWithAbstract<py::list>(value_list, seq_abs);
}

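// Like BaseRefToPyData, but returns the Python object stored in PyExecute user data when it is available.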
py::object BaseRefToPyDataWithUserData(const BaseRef &value, const AbstractBasePtr &abs) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kGraphExecutorPy, runtime::ProfilerEvent::kOutputProcess,
                                     "BaseRefToPyData");
  const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
  if (!allow_fallback_runtime) {
    return BaseRefToPyData(value, abs);
  }
  if (utils::isa<ValuePtr>(value)) {
    // Do not use abs as the input to BaseRefToPyData, since the result needs to be a tensor to get the user data.
    auto res = BaseRefToPyData(value);
    MS_LOG(DEBUG) << "res: " << py::str(res);
    const auto user_data = GetUserDataFromAddress(res);
    if (user_data != nullptr) {
      return user_data->obj;
    } else {
      MS_LOG(DEBUG) << "user data is empty";
    }
  } else if (utils::isa<VectorRef>(value)) {
    auto vec_ref = utils::cast<VectorRef>(value);
    return GetVectorRefPyData(vec_ref, abs);
  }
  return BaseRefToPyData(value, abs);
}

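// Recursively registers every FuncGraph contained in the value (including those nested in sequences and
// dictionaries) with the manager.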
void AddManager(const FuncGraphManagerPtr &manager, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<FuncGraph>()) {
    auto fg = value->cast<FuncGraphPtr>();
    manager->AddFuncGraph(fg);
  }
  if (value->isa<ValueSequence>()) {
    auto value_sequence = value->cast<ValueSequencePtr>();
    for (const auto &elem : value_sequence->value()) {
      AddManager(manager, elem);
    }
  }
  if (value->isa<ValueDictionary>()) {
    for (const auto &elem : value->cast<ValueDictionaryPtr>()->value()) {
      AddManager(manager, elem.second);
    }
  }
}

void AddManagerForFuncGraphArgs(const ResourcePtr &resource, const ValuePtrList &arguments) {
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &arg : arguments) {
    AddManager(manager, arg);
  }
}

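// Resets the global node id counter and clears each CNode's fullname_with_scope so that node names are regenerated.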
void ResetId(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  auto need_dump = common::GetCompileConfig("DUMP_VALIDATE_BEFORE_RESET_ID");
  if (context->CanDump(kIntroductory) && need_dump == "1") {
    FuncGraphPtr graph = resource->func_graph();
    DumpIR("validate_before_reset_id.ir", graph, true, kWholeStack);
  }
#endif
  mindspore::id_generator::reset_id();
  const auto &all_nodes = TopoSort(resource->func_graph()->get_return(), SuccDeeperSimple);
  for (const auto &node : all_nodes) {
    if (node != nullptr && node->isa<CNode>()) {
      const auto &cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      cnode->set_fullname_with_scope("");
    }
  }
}

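// Verifies that an input shape matches the compiled shape; kShapeDimAny dimensions act as wildcards.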
void CheckShapeConsistency(const abstract::ShapePtr &compile_shape, const abstract::ShapePtr &args_shape,
                           const std::string &target_str, size_t index) {
  MS_EXCEPTION_IF_NULL(compile_shape);
  MS_EXCEPTION_IF_NULL(args_shape);
  if (*compile_shape == *args_shape) {
    return;
  }

  auto compile_shape_vec = compile_shape->shape();
  auto args_shape_vec = args_shape->shape();

  if (!IsDynamicRank(compile_shape_vec)) {
    if (!args_shape_vec.empty() && compile_shape_vec.size() != args_shape_vec.size()) {
      MS_EXCEPTION(ValueError) << "For " << target_str << " and tuple(list) in " << target_str << ", the dims of "
                               << index + 1 << "th input must be the same as expected, "
                               << "but got expected: " << compile_shape_vec.size()
                               << ", and input: " << args_shape_vec.size() << "!";
    }

    for (size_t i = 0; i < compile_shape_vec.size(); ++i) {
      if (compile_shape_vec[i] == abstract::Shape::kShapeDimAny || compile_shape_vec[i] == args_shape_vec[i]) {
        continue;
      }
      MS_EXCEPTION(ValueError) << "For " << target_str << " and tuple(list) in " << target_str << ", the shape of "
                               << index + 1 << "th input must be the same as expected, "
                               << "but got expected: " << compile_shape_vec[i] << ", and input: " << args_shape_vec[i]
                               << "!";
    }
  }
}

inline void CheckSizeConsistency(const AbstractBasePtrList &compile_abstracts,
                                 const AbstractBasePtrList &args_abstracts, const std::string &target_str,
                                 bool dynamic_len = false) {
  if (!dynamic_len && compile_abstracts.size() != args_abstracts.size()) {
    MS_EXCEPTION(ValueError) << "For " << target_str << " and tuple(list) in " << target_str
                             << ", the length of input must be equal to expected one, but got expected: "
                             << compile_abstracts.size() << " and input: " << args_abstracts.size() << "!";
  }
  if (dynamic_len && compile_abstracts.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "For " << target_str << ", the dynamic_len compile arguments should not be empty!";
  }
}

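// Recursively checks that input abstracts match the compiled abstracts in mutability, shape and element type.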
void CheckAbstractConsistency(const AbstractBasePtrList &compile_abstracts, const AbstractBasePtrList &args_abstracts,
                              const std::string &target_str, bool dynamic_len = false) {
  CheckSizeConsistency(compile_abstracts, args_abstracts, target_str, dynamic_len);
  for (size_t i = 0; i < args_abstracts.size(); ++i) {
    auto compile_abs = dynamic_len ? compile_abstracts[0] : compile_abstracts[i];
    auto args_abs = args_abstracts[i];
    auto is_compile_var = compile_abs->BuildValue()->ContainsValueAny();
    auto is_args_var = args_abs->BuildValue()->ContainsValueAny();
    if (is_compile_var != is_args_var) {
      MS_EXCEPTION(TypeError) << "For " << target_str << " or tuple(list) in " << target_str << ", the " << i + 1
                              << "th should be " << (is_compile_var ? "mutable" : "static") << " one, but got "
                              << (is_args_var ? "mutable" : "static") << "!";
    }

    if (is_compile_var) {
      if (compile_abs->isa<abstract::AbstractTensor>() && args_abs->isa<abstract::AbstractTensor>()) {
        auto compile_tensor = compile_abs->cast<abstract::AbstractTensorPtr>();
        auto args_tensor = args_abs->cast<abstract::AbstractTensorPtr>();

        // Check shape's consistency.
        auto compile_shape = compile_tensor->shape();
        auto args_shape = args_tensor->shape();
        CheckShapeConsistency(compile_shape, args_shape, target_str, i);

        auto compile_element = compile_tensor->element();
        auto args_element = args_tensor->element();
        if (!common::IsEqual(compile_element, args_element)) {
          MS_EXCEPTION(TypeError) << "For " << target_str << " or tuple(list) in " << target_str << ", the " << i + 1
                                  << "th type should be " << compile_tensor->BuildType()->ToString() << ", but got "
                                  << args_tensor->BuildType()->ToString() << "!";
        }
      } else if (compile_abs->isa<abstract::AbstractSequence>() && args_abs->isa<abstract::AbstractSequence>()) {
        auto compile_sequence = compile_abs->cast<abstract::AbstractSequencePtr>();
        auto args_sequence = args_abs->cast<abstract::AbstractSequencePtr>();
        CheckAbstractConsistency(compile_sequence->elements(), args_sequence->elements(), target_str,
                                 compile_sequence->dynamic_len());
      } else {
        if (!common::IsEqual(compile_abs, args_abs)) {
          MS_EXCEPTION(ValueError) << "For " << target_str << " or tuple(list) in " << target_str << ", the " << i + 1
                                   << "th should be " << compile_abs->ToString() << ", but got " << args_abs->ToString()
                                   << "!";
        }
      }
    } else if (compile_abs->isa<abstract::AbstractList>() && args_abs->isa<abstract::AbstractList>()) {
      auto compile_sequence = compile_abs->cast<abstract::AbstractSequencePtr>();
      auto args_sequence = args_abs->cast<abstract::AbstractSequencePtr>();
      CheckAbstractConsistency(compile_sequence->elements(), args_sequence->elements(), target_str);
    } else {
      if (!common::IsEqual(compile_abs, args_abs)) {
        MS_EXCEPTION(ValueError) << "For " << target_str << " or tuple(list) in " << target_str << ", the " << i + 1
                                 << "th should be " << compile_abs->ToString() << ", but got " << args_abs->ToString()
                                 << "!";
      }
    }
  }
}
}  // namespace

std::string GetObjDesc(const py::object &source) {
  std::string obj_desc;
  if (py::hasattr(source, parse::PYTHON_PARSE_METHOD)) {
    auto cell_class_name = source.attr("__class__").attr("__name__");
    auto jit_name = source.attr(parse::PYTHON_PARSE_METHOD);
    obj_desc = "'" + py::cast<std::string>(cell_class_name) + "." + py::cast<std::string>(jit_name) + "'";
  } else {
    if (py::hasattr(source, "__name__")) {
      auto jit_name = source.attr("__name__");
      obj_desc = "'" + py::cast<std::string>(jit_name) + "'";
    } else if (py::isinstance<Cell>(source)) {
      auto cell_class_name = source.attr("__class__").attr("__name__");
      obj_desc = "'" + py::cast<std::string>(cell_class_name) + ".construct'";
    } else {
      MS_EXCEPTION(TypeError) << "The source object is invalid: " << py::str(source);
    }
  }
  return obj_desc;
}

void CheckArgsValid(const py::object &source, const py::tuple &args) {
  if (!IS_OUTPUT_ON(mindspore::kInfo)) {
    return;
  }
  for (size_t i = 0; i < args.size(); i++) {
    if (!CheckArgValid(args[i])) {
      MS_LOG(INFO) << "The " << ToOrdinal(i + 1) << " arg type is " << args[i].get_type() << ", value is '"
                   << py::str(args[i]) << "'.";
    }
  }
}

void GraphExecutorPy::CheckArgumentsConsistency(const py::tuple &compile_args, const py::tuple &args_list,
                                                const py::object &target) {
  if ((!py::isinstance<py::str>(target))) {
    MS_EXCEPTION(TypeError) << "The `target` must be a string!";
  }
  std::string target_str = py::cast<std::string>(target);
  if (compile_args.size() != args_list.size()) {
    MS_EXCEPTION(ValueError) << "For " << target_str
                             << ", the length of input must be equal to expected one, but got expected: "
                             << compile_args.size() << " and input: " << args_list.size() << "!";
  }

  AbstractBasePtrList compile_abstracts;
  compile_abstracts.reserve(compile_args.size());
  AbstractBasePtrList args_abstracts;
  args_abstracts.reserve(compile_args.size());
  for (size_t i = 0; i < compile_args.size(); ++i) {
    ValuePtr compile_args_converted = nullptr;
    if (!parse::ConvertData(compile_args[i], &compile_args_converted)) {
      MS_LOG(INTERNAL_EXCEPTION) << "ConvertData for " << i << "th compiling argument failed, the argument type is "
                                 << compile_args[i].get_type() << ", value is '" << py::str(compile_args[i]) << "'.";
    }
    compile_abstracts.push_back(ArgsToAbstract(compile_args[i], compile_args_converted));

    ValuePtr args_converted = nullptr;
    if (!parse::ConvertData(args_list[i], &args_converted)) {
      MS_LOG(INTERNAL_EXCEPTION) << "ConvertData for " << i << "th input argument failed, the argument type is "
                                 << args_list[i].get_type() << ", value is '" << py::str(args_list[i]) << "'.";
    }
    args_abstracts.push_back(ArgsToAbstract(args_list[i], args_converted));
  }

  CheckAbstractConsistency(compile_abstracts, args_abstracts, target_str, false);
}

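// Computes the compile cache key for the given arguments: every positional and keyword argument is converted to an
// abstract value, the conversions are cached in cur_convert_input_, and an existing key is reused when the same
// abstract list has been seen before.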
py::object GraphExecutorPy::GenerateArgumentsKey(const py::object &obj, const py::tuple &args, const py::dict &kwargs,
                                                 bool enable_tuple_broaden) {
  MS_LOG(DEBUG) << "GenerateArgumentsKey args size: " << args.size()
                << ", enable_tuple_broaden: " << enable_tuple_broaden;

  abstract::AbstractBasePtrList args_abs;
  ClearCurConvertInput();
  for (std::size_t i = 0; i < args.size(); i++) {
    ValuePtr converted = nullptr;
    if (!parse::ConvertData(args[i], &converted)) {
      MS_LOG(INTERNAL_EXCEPTION) << "ConvertData for " << i << "th argument failed, the argument type is "
                                 << args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
    }
    AbstractBasePtr abs = ArgsToAbstract(args[i], converted, enable_tuple_broaden);
    (void)args_abs.emplace_back(abs);
    // The 'converted' may be a Parameter; we need to connect it to the Parameter of the func graph,
    // so we keep all inputs for the subsequent procedure.
    (void)cur_convert_input_.emplace(args[i].ptr(), std::make_pair(converted, abs));
  }
  for (const auto &item : kwargs) {
    ValuePtr key = nullptr;
    ValuePtr value = nullptr;
    bool success = parse::ConvertData(py::cast<py::object>(item.first), &key) &&
                   parse::ConvertData(py::cast<py::object>(item.second), &value);
    if (!success) {
      MS_LOG(INTERNAL_EXCEPTION) << "ConvertData for argument (" << py::str(item.first) << ": " << py::str(item.second)
                                 << ") failed.";
    }
    AbstractBasePtr value_abs = ArgsToAbstract(py::cast<py::object>(item.second), value, enable_tuple_broaden);
    auto keyword_arg_abs = std::make_shared<abstract::AbstractKeywordArg>(GetValue<std::string>(key), value_abs);
    (void)args_abs.emplace_back(keyword_arg_abs);
    (void)cur_convert_input_.emplace(item.first.ptr(), std::make_pair(value, keyword_arg_abs));
  }

  // If the cache matches, there is no need to call CheckArgsValid.
  auto iter = kArgsCache.find(args_abs);
  if (iter != kArgsCache.end()) {
    return py::int_(iter->second);
  }

  static uint64_t key_counter = 0;
  kArgsCache[args_abs] = key_counter;
  kCellArgsMap[obj.ptr()] = args_abs;
  MS_LOG(INFO) << "Generate a new compile key for new args, key: " << key_counter;
  if (IS_OUTPUT_ON(mindspore::kInfo)) {
    std::ostringstream buffer;
    buffer << "New cached args:"
           << "\n";
    for (size_t i = 0; i < args_abs.size(); ++i) {
      buffer << "Arg[" << i << "]: " << args_abs[i]->ToString() << "\n";
    }
    MS_LOG(INFO) << buffer.str();
  }
  return py::int_(key_counter++);
}

void GraphExecutorPy::ClearCompileArgumentsResource() {
  // Clear global converted args saved in GenerateArgumentsKey.
  ClearCurConvertInput();
}

void ClearArgCache(const py::object &obj) {
  if (py::isinstance<py::none>(obj)) {
    return;
  }
  auto iter = kCellArgsMap.find(obj.ptr());
  if (iter != kCellArgsMap.end()) {
    (void)kArgsCache.erase(iter->second);
    (void)kCellArgsMap.erase(iter);
  }
}

void GraphExecutorPy::ClearCurConvertInput() { cur_convert_input_.clear(); }

void GraphExecutorPy::ParentBeforeFork() {
  MS_LOG(DEBUG) << "GraphExecutorPy prepare before fork.";
  MS_LOG(DEBUG) << "Stop AnalysisSchedule tasks.";
  abstract::AnalysisSchedule::GetInstance().Stop();
  MS_LOG(DEBUG) << "GraphExecutorPy prepare before fork done.";
}

void GraphExecutorPy::ParentAfterFork() {
  MS_LOG(DEBUG) << "GraphExecutorPy in parent process reinitialize after fork.";
  MS_LOG(DEBUG) << "Restart AnalysisSchedule tasks.";
  abstract::AnalysisSchedule::GetInstance().Start();
  MS_LOG(DEBUG) << "GraphExecutorPy in parent process reinitialize after fork done.";
}

void GraphExecutorPy::ChildAfterFork() {
  MS_LOG(DEBUG) << "GraphExecutorPy in child process reinitialize after fork.";
  MS_LOG(DEBUG) << "Restart AnalysisSchedule tasks.";
  abstract::AnalysisSchedule::GetInstance().Start();
  MS_LOG(DEBUG) << "GraphExecutorPy in child process reinitialize after fork done.";
}

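// Checks each input tensor against the corresponding entry of input_signature; shape and dtype must match exactly.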
py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs) {
  MS_LOG(DEBUG) << "Verify args size:" << inputs.size();
  if (inputs.size() != input_signature.size()) {
    MS_LOG(ERROR) << "Signature size not equal to args size";
    return false;
  }

  size_t count = 0;
  for (auto arg_obj : inputs) {
    std::shared_ptr<Tensor> m_tensor = nullptr;
    bool is_tensor = false;
    if (py::isinstance<Tensor>(arg_obj)) {
      m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
      is_tensor = true;
    } else if (IsStubTensor(arg_obj)) {
      m_tensor = ConvertStubTensor(arg_obj);
      is_tensor = true;
    }
    if (is_tensor && m_tensor == nullptr) {
      MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
      return false;
    }

    if (m_tensor != nullptr) {
      MS_LOG(DEBUG) << "Verify Tensor";
      auto sig = input_signature[count].cast<std::shared_ptr<MetaTensor>>();
      ShapeVector sig_shape = sig->shape();
      TypePtr sig_type = sig->Dtype();

      ShapeVector tensor_shape = m_tensor->shape_c();
      if (tensor_shape != sig_shape) {
        MS_LOG(ERROR) << "Python input shape is incompatible with input_signature";
        return false;
      }

      if (*m_tensor->Dtype() != *sig_type) {
        MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString() << ") incompatible with input_signature("
                      << sig_type->ToString() << ")";
        return false;
      }
    }
    count++;
  }

  return true;
}

ResourcePtr GraphExecutorPy::GetResource(const std::string &phase) {
  MS_LOG(DEBUG) << "Phase size:" << info_.size();
  if (info_.count(phase) == 0) {
    return nullptr;
  }
  return info_[phase]->resource;
}

FuncGraphPtr GraphExecutorPy::GetFuncGraph(const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INFO) << "No executor info found for phase: " << phase;
    return nullptr;
  }
  return it->second->func_graph;
}

void GraphExecutorPy::SetJitPrimalFuncGraph(const FuncGraphPtr &primal_func_graph, const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "No executor info found for phase: " << phase;
    return;
  }
  MS_EXCEPTION_IF_NULL(primal_func_graph);
  it->second->jit_primal_func_graph = primal_func_graph;
}

FuncGraphPtr GraphExecutorPy::GetJitPrimalFuncGraph(const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INFO) << "No executor info found for phase: " << phase;
    return nullptr;
  }
  return it->second->jit_primal_func_graph;
}

FuncGraphPtr GraphExecutorPy::GetJitGradGraph(const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INFO) << "No executor info found for phase: " << phase;
    return nullptr;
  }
  return it->second->jit_grad_graph;
}

void GraphExecutorPy::SetJitGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "No executor info found for phase: " << phase;
    return;
  }
  if (it->second->jit_grad_graph != nullptr) {
    MS_LOG(DEBUG) << "The grad graph already exists, phase is: " << phase;
  }
  MS_EXCEPTION_IF_NULL(grad_graph);
  it->second->jit_grad_graph = grad_graph;
}

compile::VmEvalFuncPtr GraphExecutorPy::GetVmEvalFunc(const std::string &phase) {
  ResourcePtr res = GetResource(phase);
  MS_EXCEPTION_IF_NULL(res);
  if (res->HasResult(kOutput) && res->GetResult(kOutput).is<compile::VmEvalFuncPtr>()) {
    return res->GetResult(kOutput).cast<compile::VmEvalFuncPtr>();
  }
  MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput;
  return nullptr;
}

bool GraphExecutorPy::HasCompiled(const std::string &phase) const { return info_.count(phase) != 0; }

py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type,
                                             const bool &incremental) {
  FuncGraphPtr fg_ptr = GetFuncGraph(phase);
  if (fg_ptr == nullptr) {
    for (const auto &item : info_) {
      MS_LOG(DEBUG) << "Phase key is: " << item.first;
    }
    MS_LOG(EXCEPTION) << "Can not find func graph " << phase;
  }

  if (ir_type == IR_TYPE_ANF) {
    std::string proto_str = GetFuncGraphProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export ANF format model failed.";
    }
    return proto_str;
  }

  if (ir_type == IR_TYPE_ONNX) {
    std::string proto_str = GetOnnxProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export ONNX format model failed.";
    }
    return proto_str;
  }

  if (ir_type == IR_TYPE_MINDIR) {
    // obfuscate model
    std::string proto_str = GetBinaryProtoString(fg_ptr, incremental);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export MINDIR format model failed.";
    }
    return proto_str;
  }

  MS_LOG(INTERNAL_EXCEPTION) << "Unknown ir type: " << ir_type;
}

py::bytes GraphExecutorPy::GetObfuscateFuncGraphProto(const std::string &phase, const bool &incremental,
                                                      const float obf_ratio, const int branch_control_input) {
  FuncGraphPtr fg_ptr = GetFuncGraph(phase);
  // obfuscate model
  if (branch_control_input == 0) {
    (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
    MS_LOG(DEBUG) << "[GetObfuscateFuncGraphProto] set customized function names finished";
  }
  mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, branch_control_input);
  mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(fg_ptr);

  std::string proto_str = GetBinaryProtoString(obfuscated_graph, incremental);
  if (proto_str.empty()) {
    MS_LOG(EXCEPTION) << "GetBinaryProtoString failed.";
  }
  return proto_str;
}

py::bytes GraphExecutorPy::GetOptimizeGraphProto(const std::string &phase) {
  if (info_.count(phase) == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "No phase in executor: " << phase;
  }
  FuncGraphPtr fg_ptr = info_[phase]->resource->optimize_graph();
  if (fg_ptr == nullptr) {
    MS_LOG(WARNING) << "Can not find optimize graph.";
    return "";
  }
  std::string proto_str = GetFuncGraphProtoString(fg_ptr);
  if (proto_str.empty()) {
    MS_LOG(EXCEPTION) << "Export optimize graph proto string failed.";
  }
  return proto_str;
}

void GraphExecutorPy::SetJitConfig(const py::dict &config) {
  auto jit_config = GenerateJitConfigMap(config);
  PhaseManager::GetInstance().set_jit_config(jit_config);
}

py::dict GraphExecutorPy::GetParallelGraphInfo(const std::string &phase) {
  MS_LOG(DEBUG) << "GetParallelGraphInfo!";
  std::string parallel_phase = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(parallel_phase);
  if (graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Can not access FuncGraph according to phase: " << parallel_phase;
  }

  return mindspore::parallel::GetParallelCNodeInfoFromGraph(graph);
}

py::dict GraphExecutorPy::GetParameterLayout(const std::string &phase) {
  MS_LOG(DEBUG) << "GetParameterLayout!";
  std::string layout_graph = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(layout_graph);
  if (graph == nullptr) {
    auto resource = info_[phase]->resource;
    return mindspore::parallel::GetParameterLayoutFromResource(resource);
  }
  return mindspore::parallel::GetParameterLayoutFromGraph(graph);
}

py::tuple GraphExecutorPy::FlopsCollection(const std::string &phase) {
  auto graph = GetFuncGraph(phase);
  return mindspore::parallel::FlopsCollection(graph);
}

py::dict GraphExecutorPy::GetCNodeStrategy(const std::string &phase) {
  MS_LOG(DEBUG) << "GetCNodeStrategy!";
  return stra_dict_[phase];
}

py::list GraphExecutorPy::GetParallelParameterNameList(const std::string &phase) {
  std::string param_graph = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(param_graph);
  if (graph == nullptr) {
    auto resource = info_[phase]->resource;
    return mindspore::parallel::GetParallelParameterNameListFromResource(resource);
  }
  return mindspore::parallel::GetParallelParameterNameListFromGraph(graph);
}

void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategies &strategy) {
  MS_LOG(DEBUG) << "SetCNodeStrategy!";
  stra_dict_[phase_][py::str(name)] = strategy;
}

size_t GraphExecutorPy::GetNumOpsInfo(const std::string &phase) {
  MS_LOG(DEBUG) << "GetNumOpsInfo!";
  return phase_to_num_op_info_[phase];
}

void GraphExecutorPy::SetNumOpsInfo(size_t num_ops) {
  MS_LOG(DEBUG) << "SetNumOpsInfo!";
  phase_to_num_op_info_[phase_] = num_ops;
}

py::dict GraphExecutorPy::GetAllreduceFusion(const std::string &phase) {
  MS_LOG(INFO) << "GetAllreduceFusion!";
  auto graph = GetFuncGraph(phase);
  return mindspore::parallel::GetAllreduceFusion(graph);
}

// Multi-threading is not supported, and neither are nested calls.
// A nested_called flag is used here to avoid nested calls.
void GraphExecutorPy::DelNetRes(const py::object &source, const py::set &id) {
  ClearArgCache(source);
  // Delete all graphs for the different phases.
  for (auto item : id) {
    DelOneNetRes(item);
  }
}

void GraphExecutorPy::DelOneNetRes(const py::handle &py_phase) {
  if (!pybind11::isinstance<py::str>(py_phase)) {
    MS_LOG(ERROR) << "Expect string phase, but got " << py::str(py_phase);
    return;
  }
  auto phase = pybind11::cast<std::string>(py_phase);
  MS_LOG(INFO) << "Delete one net resource start, phase: " << phase;
  auto iter = info_.find(phase);
  auto clear = false;
  if (iter != info_.end()) {
    clear = true;
    auto res = iter->second->resource;
    if (res->HasResult(kStepParallelGraph)) {
      std::string layout_graph = phase + kStepParallelGraph;
      (void)info_.erase(layout_graph);
    }
    (void)info_.erase(phase);
    MS_LOG(DEBUG) << "Delete phase: " << phase << ", info size: " << info_.size();
  }
  if (clear) {
    // Do the clearing here to avoid any dangling pointer to the resource.
    FuncGraphLoopBreaker::Inst().ClearCellGraphs(phase);
    FuncGraphLoopBreaker::Inst().CleanUnusedFuncGraphs(phase);
  }
  MS_LOG(INFO) << "Delete one net resource end. " << clear;
}

void GraphExecutorPy::ClearRes() {
  MS_LOG(INFO) << "Clean executor resource!";
  executor_ = nullptr;
}

std::string GraphExecutorPy::get_queue_name(const std::string &dataset_phase) {
  return CompileCacheManager::GetCachedDataQueueName(dataset_phase);
}

GraphExecutorPy::~GraphExecutorPy() {
  MS_LOG(INFO) << "Release Executor!";
  ConfigManager::GetInstance().ResetConfig();
}

void GraphExecutorPy::SaveCompiledGraph(const std::string &phase) {
  // Save the graph to GraphExecutorPy.
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase << ")!";
  info_[phase]->func_graph = func_graph;
  func_graph->set_attr("phase", MakeValue(GetPhasePrefix(phase)));

  if ((func_graph != nullptr) && parallel::IsAutoParallelCareGraph(func_graph)) {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
    auto res = info_[phase]->resource;
    // When using the frontend compile cache, the model parallel parameter layout graph is not saved.
    if (res->HasResult(kStepParallelGraph)) {
      func_graph = res->GetResult(kStepParallelGraph).cast<FuncGraphPtr>();
      ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
      std::string layout_graph = phase + kStepParallelGraph;
      executor_info->func_graph = func_graph;
      info_[layout_graph] = executor_info;
    }
  } else {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!";
  }
  MS_LOG(INFO) << "End save compiled func graph!";
}

void GraphExecutorPy::GetGeBackendPolicy() const {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string backend = ms_context->backend_policy();
  if (backend != "ge") {
    MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!";
  }
}

bool IsPhaseExportAir(const std::string &phase) {
  auto phase_to_export = "export.air";
  return phase.rfind(phase_to_export) != std::string::npos;
}

bool IsPhaseExport(const std::string &phase) {
  constexpr auto export_str = "export";
  return phase.compare(0, strlen(export_str), export_str) == 0;
}

bool IsPhaseTrain(const std::string &phase) {
  const std::string phase_to_train = "train";
  return phase.rfind(phase_to_train) != std::string::npos;
}

bool IsPhaseLoadFromMindIR(const std::string &phase) {
  const std::string mindir_graph = "graph_load_from_mindir";
  return phase.rfind(mindir_graph) != std::string::npos;
}

std::vector<ActionItem> GetActions(const ResourcePtr &resource, const std::string &phase, bool use_vm,
                                   bool trace_flag = false, bool erase_parse = false) {
  MS_EXCEPTION_IF_NULL(resource);
  compile::SetMindRTEnable();
  return VmPipeline(resource, trace_flag, erase_parse);
}

void GraphExecutorPy::InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase) {
  // The compilation cache currently only supports training cells or functions decorated with 'jit'.
  // If compilation caching is enabled, a non-empty list of dependent files is obtained from Python.
  if (!CompileCacheEnable()) {
    return;
  }
  bool has_python_script = true;
  if (compile_cache_dep_files_.empty()) {
    has_python_script = false;
  }

  {
    MsProfileStatGuard stat_guard("LoadCachedFuncGraph");
    static size_t idx = 0;
    MS_EXCEPTION_IF_NULL(resource);
    resource->GetCompileCacheResource(compile_cache_dep_files_, weights_, queue_name_, idx++,
                                      &compile_cache_consistent_, has_python_script);
  }
}

void GraphExecutorPy::ParallelPostProcess(const std::string &phase, bool use_compile_cache) {
  // Slice the Python parameter objects.
  auto layout_graph = phase + kStepParallelGraph;
  // Only a parallel graph has a tensor_layout.
  auto root = GetFuncGraph(layout_graph);
  bool after_shard = false;
  if (phase.find("after_shard") != std::string::npos) {
    after_shard = true;
  }
  // Use the compile cache.
  if (use_compile_cache) {
    parallel::InitCompileCacheParams(info_[phase]->resource);
    return;
  }
  // Initialize parameters for graphs that auto-parallel does not care about.
  if (root == nullptr && !after_shard) {
    auto graph = info_[phase]->resource->func_graph();
    MS_EXCEPTION_IF_NULL(graph);
    parallel::InitPynativeNoShardParams(graph);
    return;
  }
  MS_EXCEPTION_IF_NULL(root);
  parallel::AutoParallelPostProcess(root);
}

// Clean all resources that will not be used in the future, and the caches generated during compiling.
void GraphExecutorPy::CleanCompileRes(const ResourcePtr &resource) {
  MS_LOG(INFO) << "Clean compile resource start";
  ProcessStatus::GetInstance().RecordStart(kPipelineClean);
  (void)profiler::CollectHostInfo(kCompiler, kPipelineClean, kPipelineClean, 0, 0, 0);
  abstract::AnalysisContext::ClearContext();
  ClearCompileArgumentsResource();
  ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
  ad::g_k_prims.clear();
  ad::DFunctor::Clear();
  ReclaimOptimizer();
  resource->Clean();
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  if (parallel_context->hccl_test_available()) {
    parallel::g_device_manager = nullptr;
  }
  FuncGraphLoopBreaker::Inst().CleanMetaFuncGraphs();
  (void)profiler::CollectHostInfo(kCompiler, kPipelineClean, kPipelineClean, 0, 0, 1);
  ProcessStatus::GetInstance().RecordEnd();
  CompileCacheContext::GetInstance().Clear();
  parse::Parser::CleanParserResource();
  MS_LOG(INFO) << "Clean compile resource end";
}

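// Compiles an existing FuncGraph for the given phase: builds the action pipeline, prepares the backend
// asynchronously, converts the arguments to abstracts, runs the pipeline, and saves the compiled graph.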
bool GraphExecutorPy::CompileInner(const FuncGraphPtr &graph, const py::tuple &args, const py::dict &kwargs,
                                   const std::string &phase, bool use_vm, bool trace_flag) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->SetCellReuseLevel(CellReuseLevel::kNoCellReuse);
  PhaseManager::GetInstance().set_phase(phase);
  phase_ = phase;

  ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
  ResourcePtr resource = std::make_shared<Resource>();
  resource->set_func_graph(graph);
  InitCompileCacheInfo(resource, phase);
  bool use_compile_cache = resource->EnableCompileCache() && resource->func_graph();
  ConfigManager::GetInstance().ResetQueue(queue_name_);

  bool erase_parse = true;
  auto actions = GetActions(resource, phase, use_vm, trace_flag, erase_parse);
  std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, actions);

  if (pip->NeedCreateBackend()) {
    // Create the backend asynchronously.
    resource->SetBackendAsync([]() {
      auto backend = compile::CreateBackend();
#ifdef ENABLE_DEBUGGER
      // Connect the session to the debugger.
      backend->SetDebugger();
#endif
      return backend;
    });
  }

  // Get the parameter items and add their values to args_abs.
  abstract::AbstractBasePtrList args_abs;
  std::vector<ValuePtr> arguments;
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  bool is_auto_parallel = (parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
                           parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel);
  ConvertArgs(args, kwargs, is_auto_parallel, &args_abs, &arguments);
  ConvertSymbolicShape(args, &args_abs);
  AddManagerForFuncGraphArgs(resource, arguments);
  resource->set_arguments(arguments);
  resource->set_args_abs(args_abs);
  executor_info->arg_list_size = args.size() + kwargs.size();
  executor_info->resource = resource;
  info_[phase] = executor_info;
  pip->Run();

  // Save the compiled graph to MsPipeLine.
  SaveCompiledGraph(phase);
  if (is_auto_parallel) {
    ParallelPostProcess(phase, use_compile_cache);
  }
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::Snapshot();
#endif
  CleanCompileRes(resource);
  PhaseManager::GetInstance().ClearPhase();
  MS_LOG(INFO) << "Finish compiling.";
  return true;
}

bool GraphExecutorPy::CompileInner(const py::object &source, const py::tuple &args, const py::dict &kwargs,
                                   const py::object &phase, bool use_vm) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->SetCellReuseLevel(CellReuseLevel::kNoCellReuse);
  // Check if the phase is valid.
  if ((!py::isinstance<py::str>(phase))) {
    MS_LOG(ERROR) << "The `phase` must be a string.";
    return false;
  }
  // Check if the function or net is valid.
  if (py::isinstance<py::none>(source)) {
    MS_LOG(ERROR) << "The source object to compile should not be None.";
    return false;
  }
  // Check if the args of the function or net are valid.
  CheckArgsValid(source, args);

  source_ = py::cast<std::string>(py::str(source));
  phase_ = py::cast<std::string>(phase);
  PhaseManager::GetInstance().set_phase(phase_);
  obj_desc_ = GetObjDesc(source);
  MS_LOG(INFO) << "Start compiling, phase: " << phase_;
  PROF_START(compile_graph);
  MS_LOG(DEBUG) << "source: {" << source_ << "}\nargs: " << py::str(const_cast<py::tuple &>(args))
                << "\nkwargs: " << py::str(const_cast<py::dict &>(kwargs));
  EventMessage::PrintCompileStartMsg(phase_, obj_desc_);

  ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
  ResourcePtr resource = std::make_shared<Resource>(source);
  InitCompileCacheInfo(resource, phase_);
  bool enable_compile_cache = resource->EnableCompileCache();
  bool use_compile_cache = enable_compile_cache && resource->func_graph();
  ConfigManager::GetInstance().ResetQueue(queue_name_);
  auto &compile_cache_context = CompileCacheContext::GetInstance();
  compile_cache_context.SetUseCompileCache(use_compile_cache);

  auto actions = GetActions(resource, phase_, use_vm, false, false);
  std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, actions);

  (void)profiler::CollectHostInfo(kCompiler, kCreateBackend, kCreateBackend, 0, 0, 0);
  if (pip->NeedCreateBackend()) {
    // Create the backend asynchronously.
    resource->SetBackendAsync([]() {
      auto backend = compile::CreateBackend();
#ifdef ENABLE_DEBUGGER
      // Connect the session to the debugger.
      backend->SetDebugger();
#endif
      return backend;
    });
  }
  (void)profiler::CollectHostInfo(kCompiler, kCreateBackend, kCreateBackend, 0, 0, 1);

  // Get the parameter items and add their values to args_abs.
  abstract::AbstractBasePtrList args_abs;
  std::vector<ValuePtr> arguments;
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  bool is_parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
                          parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel;
  bool is_auto_parallel = is_parallel_mode && !py::hasattr(source, parallel::kSkipAutoParallelCompile) &&
                          !py::hasattr(source, parallel::kKeepInputUnchanged);
  ConvertArgs(args, kwargs, is_auto_parallel, &args_abs, &arguments);
  ConvertSymbolicShape(args, &args_abs);
  AddManagerForFuncGraphArgs(resource, arguments);
  resource->set_arguments(arguments);
  resource->set_args_abs(args_abs);
  executor_info->arg_list_size = args.size() + kwargs.size();
  executor_info->resource = resource;
  info_[phase_] = executor_info;
  pip->Run();

  // Save the compiled graph to MsPipeLine.
  SaveCompiledGraph(phase_);
  if (is_parallel_mode) {
    ParallelPostProcess(phase_, use_compile_cache);
  }
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::Snapshot();
#endif
  CleanCompileRes(resource);
  EventMessage::PrintCompileEndMsg(phase_, obj_desc_);
  PhaseManager::GetInstance().ClearPhase();
  MS_LOG(INFO) << "Finish compiling.";
  PROF_END(compile_graph);
  return true;
}

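// A minimal usage sketch of this entry point (assumption: the call shape follows the Python frontend that binds
// GraphExecutor_; the names below are illustrative, not part of this file):
//   executor = GraphExecutor_.get_instance()
//   ok = executor.compile(net, args, kwargs, "train.0", True)   # source, args, kwargs, phase, use_vm
// On success the compiled graph is stored under the phase key and can later be executed via Run().
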
void GraphExecutorPy::ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel,
                                  abstract::AbstractBasePtrList *args_abs, std::vector<ValuePtr> *arguments) {
  MS_EXCEPTION_IF_NULL(args_abs);
  MS_EXCEPTION_IF_NULL(arguments);
  for (std::size_t i = 0; i < args.size(); i++) {
    // Some parallel modes need the full tensor, which makes the args of GenerateArgumentsKey differ from those
    // used for compiling, so cur_convert_input_ cannot be used directly.
    auto iter = cur_convert_input_.find(args[i].ptr());
    if (iter != cur_convert_input_.end()) {
      (void)arguments->emplace_back(iter->second.first);
      if (is_auto_parallel) {
        auto abs_item = iter->second.second->Clone();
        (void)parallel::ExtendInputArgsAbstractShape(abs_item, i);
        (void)args_abs->emplace_back(abs_item);
        continue;
      }
      (void)args_abs->emplace_back(iter->second.second);
      continue;
    }
    ValuePtr converted = nullptr;
    bool success = parse::ConvertData(args[i], &converted);
    if (!success) {
      MS_LOG(INTERNAL_EXCEPTION) << "Failed to convert the " << i << "th argument, args[" << i
                                 << "]: " << py::str(args[i]);
    }
    (void)arguments->emplace_back(converted);
    auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
    if (is_auto_parallel) {
      (void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
    }
    (void)args_abs->emplace_back(args_abstract_item);
  }
  for (const auto &item : kwargs) {
    auto iter = cur_convert_input_.find(item.first.ptr());
    if (iter != cur_convert_input_.end()) {
      (void)arguments->emplace_back(iter->second.first);
      (void)args_abs->emplace_back(iter->second.second);
      continue;
    }
    ValuePtr key = nullptr;
    ValuePtr value = nullptr;
    bool success = parse::ConvertData(py::cast<py::object>(item.first), &key) &&
                   parse::ConvertData(py::cast<py::object>(item.second), &value);
    if (!success) {
      MS_LOG(INTERNAL_EXCEPTION) << "Failed to convert the argument (" << py::str(item.first) << ": "
                                 << py::str(item.second) << ").";
    }
    AbstractBasePtr value_abs = ArgsToAbstract(py::cast<py::object>(item.second), value, enable_tuple_broaden_);
    auto keyword_arg_abs = std::make_shared<abstract::AbstractKeywordArg>(GetValue<std::string>(key), value_abs);
    (void)arguments->emplace_back(value);
    (void)args_abs->emplace_back(keyword_arg_abs);
  }
}

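// Illustrative example of the kwargs branch above (hypothetical values): a call like net(x, lr=0.1) converts
// `lr` via parse::ConvertData and wraps its abstract into AbstractKeywordArg("lr", <abstract of 0.1>), which is
// appended to args_abs alongside the positional abstracts. The exact abstract type depends on ArgsToAbstract
// and enable_tuple_broaden_.
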
void GraphExecutorPy::ConvertSymbolicShape(const py::tuple &args, AbstractBasePtrList *args_abs) {
  std::vector<symshape::SymbolInfoList> symbol_infos;
  symbol_infos.reserve(args_abs->size());
  bool has_dyn_shape = false;
  bool is_parallel = parallel::IsSemiOrAutoParallelMode();

  for (size_t i = 0; i < args.size(); i++) {
    auto iter = cur_convert_input_.find(args[i].ptr());
    if (iter == cur_convert_input_.end()) {
      continue;
    }
    auto &info_list = symbol_infos.emplace_back(symshape::SymbolInfoList{});
    if (!iter->second.first->isa<MetaTensor>()) {
      continue;
    }
    auto digital_shape = iter->second.second->GetShape();
    if (digital_shape->IsDynamic()) {
      has_dyn_shape = true;
    }
    constexpr char symbolic_shape_attr[] = "symbolic_shape";
    if (!py::hasattr(args[i], symbolic_shape_attr)) {
      if (is_parallel) {
        if (digital_shape != nullptr && digital_shape->isa<abstract::TensorShape>()) {
          info_list.resize(digital_shape->GetShapeVector().size());
        }
      }
      continue;
    }
    auto symbolic_shape_obj = py::getattr(args[i], symbolic_shape_attr);
    MS_EXCEPTION_IF_CHECK_FAIL(py::isinstance<py::list>(symbolic_shape_obj), "tensor.symbolic_shape should be a list");
    auto obj_list = py::cast<py::list>(symbolic_shape_obj);
    info_list.resize(obj_list.size());
    for (size_t j = 0; j < obj_list.size(); j++) {
      if (!py::isinstance<py::dict>(obj_list[j])) {
        continue;
      }
      auto dict_obj = py::cast<py::dict>(obj_list[j]);
      for (auto cfg_iter = dict_obj.begin(); cfg_iter != dict_obj.end(); ++cfg_iter) {
        auto cfg_key = py::cast<std::string>(cfg_iter->first);
        if (cfg_key == "max") {
          info_list[j].max = py::cast<int64_t>(cfg_iter->second);
        } else if (cfg_key == "min") {
          info_list[j].min = py::cast<int64_t>(cfg_iter->second);
        } else if (cfg_key == "divisor") {
          info_list[j].divisor = py::cast<int64_t>(cfg_iter->second);
        } else if (cfg_key == "remainder") {
          info_list[j].remainder = py::cast<int64_t>(cfg_iter->second);
        } else if (cfg_key == "id") {
          info_list[j].id = py::cast<int64_t>(cfg_iter->second);
        } else if (cfg_key == "name") {
          info_list[j].name = py::cast<std::string>(cfg_iter->second);
        }
      }
    }
  }

  MS_LOG(DEBUG) << "Before parallel symbol conversion:";
  parallel::PrintSymbolInfo(symbol_infos);
  symbol_infos = parallel::ParallelSymbolInfo(symbol_infos, has_dyn_shape);
  MS_LOG(DEBUG) << "After parallel symbol conversion:";
  parallel::PrintSymbolInfo(symbol_infos);

  auto symbolic_shape_list = symshape::BuildSymbolicShapeBySymbolInfo(*args_abs, symbol_infos);
  for (size_t i = 0; i < symbolic_shape_list.size(); i++) {
    // When the same tensor object is used in the set_inputs interface, the inputs may share the same Abstract
    // object, but for dynamic shape, the same "-1" in the abstract can map to different symbolic shapes.
    auto abs = symshape::CloneAbstractIfSymbolExists((*args_abs)[i]);
    MS_EXCEPTION_IF_NULL(abs);
    abs->SetSymbolicShape(symbolic_shape_list[i]);
    (*args_abs)[i] = abs;
  }
}

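// Sketch of the Python-side attribute consumed above (values are hypothetical). Each list entry describes one
// dimension, and only the keys parsed in the loop above ("max", "min", "divisor", "remainder", "id", "name")
// take effect; non-dict entries are skipped:
//   x.symbolic_shape = [{"max": 64, "min": 1, "divisor": 8, "remainder": 0}, {"name": "seq_len"}]
// Dimensions without such a dict keep the plain dynamic-shape handling.
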
void GraphExecutorPy::ReleaseResourceOnException(const py::object &phase) {
  bool clear = false;
  // Make sure the pointer `res` is destroyed before calling DelOneNetRes.
  {
    ResourcePtr res = GetResource(py::cast<std::string>(phase));
    if (res != nullptr) {
      clear = true;
      CleanCompileRes(res);
    }
  }
  ProcessStatus::GetInstance().Clear();
  if (clear) {
    DelOneNetRes(phase);
  }
}

bool GraphExecutorPy::Compile(const py::object &source, const py::tuple &args, const py::dict &kwargs,
                              const py::object &phase, bool use_vm) {
  bool res = false;
  HandleExceptionRethrow(
    [this, &res, &source, &args, &kwargs, &phase, use_vm]() {
      if (executor_running_) {
        MS_LOG(EXCEPTION) << "Nested execution during JIT execution for " << GetObjDesc(source) << " is not supported "
                          << "while " << obj_desc_ << " is compiling or executing. For more details, please refer to "
                          << "https://www.mindspore.cn/search?inputValue=Nested%20execution";
      }
      ProcessStatus::GetInstance().RecordStart(kCompiler);
      std::map<std::string, std::string> custom_info;
      custom_info["phase"] = py::cast<std::string>(phase);
      (void)profiler::CollectHostInfo(kCompiler, kCompiler, kCompiler, 1, 0, 0, custom_info);
      res = CompileInner(source, args, kwargs, phase, use_vm);
      (void)profiler::CollectHostInfo(kCompiler, kCompiler, kCompiler, 1, 0, 1, custom_info);
      ProcessStatus::GetInstance().RecordEnd();
      ProcessStatus::GetInstance().Print();
    },
    [this, &phase]() {
      if (!StaticAnalysisException::Instance().HasException()) {
        // Print the function call stack info before releasing resources.
        std::string compile_exception_info = GetCompileExceptionInfo();
        if (!compile_exception_info.empty()) {
          MS_LOG(ERROR) << compile_exception_info;
        }
      }
      ReleaseResourceOnException(phase);
    },
    [this, &phase]() { ReleaseResourceOnException(phase); }, [this, &phase]() { ReleaseResourceOnException(phase); });
  return res;
}

void CacheFuncGraph(const ResourcePtr &resource) {
  if (!resource->EnableCompileCache()) {
    return;
  }
  {
    MsProfileStatGuard stat_guard("SaveCacheFuncGraph");
    resource->CacheFuncGraph();
  }
}

void CheckInterpretNodeLineInfos() {
  auto &py_interpret_nodes = InterpretNodeRecorder::GetInstance().PyInterpretNodes();
  auto &py_execute_nodes = InterpretNodeRecorder::GetInstance().PyExecuteNodes();
  if (py_interpret_nodes.empty() && py_execute_nodes.empty()) {
    return;
  }

  std::stringstream ss;
  ss << "Found unsupported syntax in graph mode; this code will fall back to the Python interpreter:\n";
  // Dump for PyInterpret.
  ss << "----------------------------------------\n";
  ss << " After Parser Phase (total: " << py_interpret_nodes.size() << ")\n";
  ss << "----------------------------------------\n";
  size_t num = 1;
  for (const auto &node : py_interpret_nodes) {
    const auto line_info = trace::GetDebugInfoStr(node->debug_info());
    ss << "# No. " << num << ":\n" << line_info << "\n";
    ++num;
  }
  ss << "\n";
  // Dump for PyExecute.
  ss << "----------------------------------------\n";
  ss << " After Optimizer Phase (total: " << py_execute_nodes.size() << ")\n";
  ss << "----------------------------------------\n";
  num = 1;
  for (const auto &node : py_execute_nodes) {
    ss << "# No. " << num << ":\n";
    const auto &cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    const auto &weak_script_node = cnode->weak_input(1);
    const auto &script_node = weak_script_node.lock();
    MS_EXCEPTION_IF_NULL(script_node);
    const auto &script = GetValueNode<StringImmPtr>(script_node);
    // Usually the script is a value node.
    std::string script_str;
    if (script != nullptr) {
      script_str = script->value();
    } else {
      const auto &script_abs = script_node->abstract();
      if (script_abs != nullptr) {
        const auto script_abs_scalar = script_abs->cast<abstract::AbstractScalarPtr>();
        auto script_value = script_abs_scalar->BuildValue();
        MS_EXCEPTION_IF_NULL(script_value);
        auto script_value_str = script_value->cast<StringImmPtr>();
        MS_EXCEPTION_IF_NULL(script_value_str);
        script_str = script_value_str->value();
      }
    }
    if (!script_str.empty()) {
      ss << "Script: " << script_str << "\n\n";
    } else {
      ss << "Node: " << node->DebugString() << "\n\n";
    }
    const auto line_info = trace::GetDebugInfoStr(node->debug_info());
    ss << line_info << "\n";
    ++num;
  }
  ss << "\n";
  ss << "----------------------------------------\n";

  // Print the code that runs in JIT Fallback.
  if (common::GetEnv("MS_DEV_FALLBACK_DUMP_NODE") == "1") {
    MS_LOG(ERROR) << ss.str();
  } else {
    MS_LOG(INFO) << ss.str();
  }
  InterpretNodeRecorder::GetInstance().Clear();
}

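// To surface the JIT Fallback report above at ERROR level (so it is visible without raising the global log
// level), set the environment variable checked in the code before launching, e.g.:
//   export MS_DEV_FALLBACK_DUMP_NODE=1
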
#ifdef ENABLE_DUMP_IR
void RDRRecordGraph(const size_t action_index, const size_t action_size, const std::string &filename,
                    const FuncGraphPtr &graph) {
  if (mindspore::RecorderManager::Instance().RdrEnable()) {
    MS_LOG(INFO) << "Recording FuncGraph in pipeline using RDR.";
    if (graph != nullptr) {
      auto graph_clone = BasicClone(graph);
      if (graph_clone != nullptr) {
        DumpGraphParams dump_params = {false, static_cast<int>(kTopStack)};
        if (action_index == action_size) {
          dump_params.dump_mode = static_cast<int>(kWholeStack);
        }
        (void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, filename, graph_clone, dump_params, ".ir");
      } else {
        MS_LOG(WARNING) << "Clone FuncGraph failed in pipeline, no FuncGraph recording in RDR.";
      }
    } else {
      MS_LOG(WARNING) << "Pipeline Resource has no FuncGraph, no FuncGraph recording in RDR.";
    }
    MS_LOG(INFO) << "Recording FuncGraph in pipeline end.";
  }
}
#endif

#ifdef ENABLE_DUMP_IR
void RecordIR(const size_t action_index, const size_t action_size, const std::string &action_name,
              const FuncGraphPtr &graph, FuncGraphPtr *user_graph) {
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory) && graph != nullptr) {
    *user_graph = graph;
    std::string base_name = GetBaseNameForIR(SizeToLong(action_index), action_name);

    // Generate the IR file in a human-readable format.
    static const auto switch_order = (common::GetEnv("MS_DEV_SAVE_GRAPHS_SORT_MODE") == "1");
    if (switch_order) {
      ExportIR(base_name + ".ir", graph);
    } else {
      DumpIR(base_name + ".ir", graph, true, kWholeStack);
    }
    if (context->CanDump(kFully)) {
      draw::Draw(base_name + ".dot", graph);
    }
  }
}
#endif

#ifndef ENABLE_SECURITY
void SaveGraphForReadability(const std::string &action_name, const FuncGraphPtr &graph, const ResourcePtr &resource) {
  if (graph != nullptr && action_name.find("optimize") != string::npos) {
#ifdef ENABLE_DUMP_IR
    auto context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(context);
    if (context->CanDump(kIntroductory)) {
      DumpIRProto(graph, action_name);
    }
#endif
    resource->set_optimize_graph(graph);
  }
}
#endif

void Pipeline::Run() {
  MS_LOG(INFO) << "Pipeline run";
  MS_EXCEPTION_IF_NULL(resource_);
  FuncGraphPtr user_graph = nullptr;
  const std::string last_compile_action = kValidate;
  bool already_print_profile = false;
  static const auto compile_profile_finish_action = common::GetCompileConfig("COMPILE_PROFILE_FINISH_ACTION");
  ProfileExecute(MsProfile::GetProfile(), [this, &user_graph, &last_compile_action, &already_print_profile]() {
    size_t i = 0;
    for (auto &action : actions_) {
#ifdef ENABLE_TIMELINE
      DumpTime &dump_time = DumpTime::GetInstance();
      dump_time.Record(action.first, GetTime(), true);
#endif
      ProcessStatus::GetInstance().RecordStart(action.first);
      (void)profiler::CollectHostInfo(kCompiler, action.first, action.first, 0, 0, 0);
      bool result = true;
      ProfileExecute(MsProfile::GetProfile()->Step(action.first), [&result, &action, this]() {
        MS_LOG(INFO) << "Status record: start " << action.first << " action.";
        result = action.second(resource_);
        MS_LOG(INFO) << "Status record: end " << action.first << " action.";
        if (IS_OUTPUT_ON(mindspore::kInfo)) {
          auto manager = resource_->func_graph()->manager();
          MS_EXCEPTION_IF_NULL(manager);
          MS_LOG(INFO) << "Extra status record: total func graphs: " << manager->func_graphs().size()
                       << ", total nodes: " << manager->all_nodes().size();
        }
      });
      (void)profiler::CollectHostInfo(kCompiler, action.first, action.first, 0, 0, 1);
      ProcessStatus::GetInstance().RecordEnd();
      if (!result) {
        MS_LOG(INTERNAL_EXCEPTION) << "Pipeline run failed in step: " << action.first;
      }

      if (EnabledProfile() && compile_profile_finish_action == action.first) {
        ProfileExecuteBreak(MsProfile::GetProfile());
        MsProfile::Print();
        already_print_profile = true;
      }

      if (action.first == kTaskEmit) {
        SetLoopCount(resource_);
      } else if (action.first == last_compile_action) {
        CheckInterpretNodeLineInfos();
        CacheFuncGraph(resource_);
#ifndef ENABLE_SECURITY
#ifdef WITH_BACKEND
        MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
        if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
          const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
            {kAscendDevice, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
          MS_EXCEPTION_IF_NULL(device_context);
          MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
          device_context->GetDeprecatedInterface()->DumpProfileParallelStrategy(resource_->func_graph());
        }
#endif
#endif
        ResetId(resource_);
      }
      FuncGraphPtr graph = resource_->func_graph();
#ifdef ENABLE_DUMP_IR
      std::string filename = GetBaseNameForIR(SizeToLong(i), action.first);
      RDRRecordGraph(i, actions_.size(), filename, graph);
      RecordIR(i, actions_.size(), action.first, graph, &user_graph);
#endif
#ifndef ENABLE_SECURITY
      SaveGraphForReadability(action.first, graph, resource_);
#endif
      i++;
#ifdef ENABLE_TIMELINE
      dump_time.Record(action.first, GetTime(), false);
#endif
    }
  });

  if (EnabledProfile()) {
    if (!already_print_profile) {
      MsProfile::Print();
    }
    MsProfile::Reset();
  }

#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory) && (user_graph != nullptr)) {
    if (context->CanDump(kFully)) {
      draw::DrawUserFuncGraph("ModelDigraph.dot", user_graph);
    }
  }
  if (common::GetEnv("DUMP_PARALLEL_INFO") == "1") {
    std::unordered_map<std::string, std::vector<uint32_t>> group_map;
    if (distributed::collective::CollectiveManager::instance()->initialized()) {
      group_map = distributed::collective::CollectiveManager::instance()->get_group_map();
    }
    if (parallel::g_device_manager == nullptr) {
      MS_LOG(WARNING) << "parallel::g_device_manager is not initialized; skip dumping parallel info.";
    } else {
      auto global_rank_id = parallel::g_device_manager->global_rank();
      DumpParallelJson("dump_parallel_info_" + std::to_string(global_rank_id) + ".json", resource_->func_graph(),
                       global_rank_id, group_map);
    }
  }
#endif
  MS_LOG(INFO) << "End";
}

bool Pipeline::NeedCreateBackend() {
  return std::any_of(actions_.begin(), actions_.end(),
                     [](const ActionItem &action) { return action.first == kTaskEmit || action.first == kExecute; });
}

void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
  MS_EXCEPTION_IF_NULL(arg_list);
  bool arg_list_inited = !arg_list->empty();
  for (std::size_t i = 0; i < args.size(); i++) {
    py::object arg = args[i];
    ValuePtr converted = nullptr;
    bool succ = parse::ConvertData(arg, &converted);
    if (!succ) {
      MS_LOG(INTERNAL_EXCEPTION) << "Failed to convert the " << i << "th arg.";
    }
    if (!arg_list_inited) {
      arg_list->push_back(converted);
      continue;
    }
    if (i >= arg_list->size()) {
      MS_LOG(INTERNAL_EXCEPTION) << "i: " << i << " is out of range: " << arg_list->size();
    }
    (*arg_list)[i] = converted;
  }

  MS_EXCEPTION_IF_NULL(res);
  auto graph = res->func_graph();
  MS_EXCEPTION_IF_NULL(graph);
  const std::vector<AnfNodePtr> &graph_params = graph->parameters();
  std::size_t graph_params_size = graph_params.size();
  if ((*arg_list).size() != graph_params_size) {
    // There may be some parameters with default values.
    for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
      MS_EXCEPTION_IF_NULL(graph_params[i]);
      auto param_ptr = (graph_params[i])->cast_ptr<Parameter>();
      MS_EXCEPTION_IF_NULL(param_ptr);
      if (!param_ptr->has_default()) {
        MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
      }
      if (!param_ptr->default_param()->isa<Tensor>()) {
        MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->ToString()
                          << "] is not initialized, need to call `.init_data()`";
      }
      arg_list->push_back(param_ptr->default_param());
    }
  }
}

void GraphExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kGraphExecutorPy, runtime::ProfilerEvent::kInputProcess,
                                     phase);
  ProcessVmArgInner(args, GetResource(phase), arg_list);
}

#ifdef ENABLE_DEBUGGER
void GraphExecutorPy::TerminateDebugger() {
  if (Common::GetDebugTerminate()) {
    MS_LOG(INFO) << "Terminate debugger and clear resources!";
    ClearResAtexit();
    exit(static_cast<int>(!Common::GetDebugExitSuccess()));
  }
}
#endif

py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase) {
  py::object res;
  HandleExceptionRethrow(
    [this, &res, &args, &phase]() {
      executor_running_ = true;

      uint64_t start_time = 0;
      PROFILER_START(start_time);
      res = RunInner(args, phase);
      PROFILER_STAGE_END(start_time, runtime::ProfilerStage::kRunGraph);

      executor_running_ = false;
    },
    [this]() { executor_running_ = false; }, [this]() { executor_running_ = false; },
    [this]() { executor_running_ = false; }, nullptr, true);
  return res;
}

#ifdef WITH_BACKEND
void GraphExecutorPy::GeFirstInitParams() {
  static bool inited = false;
  if (!inited) {
    MS_LOG(INFO) << "Start init params.";
    const auto &init_params = GetParams(phase_);
    auto ret = InitParams(init_params, phase_);
    if (ret) {
      inited = true;
    }
  }
}
#endif

void GraphExecutorPy::ClearRunArgumentsResource(size_t input_arg_size, VectorRef *arg_list) {
  for (std::size_t i = 0; i < input_arg_size; ++i) {
    (*arg_list)[i] = nullptr;
  }
}

py::object GraphExecutorPy::RunInner(const py::tuple &args, const py::object &phase_obj) {
  if (common::GetEnv(kSimulationLevel) == kSimulationLevelCompileGraph) {
    py::int_ ret = 0;
    return ret;
  }
  // Initialize for dynamic-obfuscated model inference.
  (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
  // The MindSpore debugger notifies the main thread to exit after one step; the next step will not run.
#ifdef ENABLE_DEBUGGER
  TerminateDebugger();
#endif
  if (!py::isinstance<py::str>(phase_obj)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Run failed, the phase input is not a str";
  }
  auto phase = py::cast<std::string>(phase_obj);
  auto phase_prefix = GetPhasePrefix(phase);
  PhaseManager::GetInstance().set_phase(phase_prefix);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
  if (enable_infer_boost) {
    PhaseManager::GetInstance().set_phase(phase);
  }
#ifdef WITH_BACKEND
  if (ms_context->backend_policy() == "ge") {
    if (!IsEnableRefMode()) {
      GeFirstInitParams();
    }

    if (phase_prefix == "save") {
      auto pos = phase.find('.');
      std::string origin_phase = phase.substr(pos + 1);
      FuncGraphPtr func_graph = info_["train." + origin_phase]->func_graph;
      MS_EXCEPTION_IF_NULL(func_graph);
      MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
      auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET),
         MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
      MS_EXCEPTION_IF_NULL(device_context);
      MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
      device_context->GetDeprecatedInterface()->DoExecNonInputGraph("save." + func_graph->ToString());
      ConfigManager::GetInstance().ResetConfig();
      return py::none();
    }
  }
#endif
  auto ret_val = std::make_shared<py::object>();
  if (info_.count(phase) != 0 && info_[phase]->func_graph != nullptr) {
    if (IsGraphOutputValueNodeOrParameter(info_[phase]->func_graph->output(), args, ret_val)) {
      return *ret_val;
    }
  }
#ifndef WITH_BACKEND
  if (ms_context->backend_policy() == "ge") {
    // Virtual output constructed for test cases.
    if (!args.empty()) {
      return args[0];
    }
    return args;
  }
#endif
  auto iter = info_.find(phase);
  if (iter == info_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "No executor info found for phase: " << phase;
  }
  auto &execute_info = iter->second;
  MS_EXCEPTION_IF_NULL(execute_info);
  if (args.size() > execute_info->arg_list_size) {
    MS_LOG(WARNING) << "The args size: " << args.size() << ", full_arg_size: " << execute_info->arg_list_size;
  }
  ProcessVmArg(args, phase, &execute_info->arg_list);
  // Start to run the phase.
  compile::VmEvalFuncPtr run = GetVmEvalFunc(phase);
  if (run == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Can't find the run graph func for " << phase;
  }

  MS_LOG(DEBUG) << "Eval run " << ms_context->backend_policy();
  const auto &output = execute_info->func_graph->output();
  MS_EXCEPTION_IF_NULL(output);
  const auto &output_abs = output->abstract();
  MS_EXCEPTION_IF_NULL(output_abs);
  BaseRef value = (*run)(execute_info->arg_list);
  bool need_recovery = distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
                       distributed::recovery::RecoveryContext::GetInstance()->need_reset();
  if (need_recovery) {
    // In the recovery scenario, the output value could be empty, so do not transform the return data.
    return py::none();
  }
  py::object res = BaseRefToPyDataWithUserData(value, output_abs);
  ClearRunArgumentsResource(args.size(), &execute_info->arg_list);
  PhaseManager::GetInstance().ClearPhase();
  MS_LOG(DEBUG) << "Run end";
  return res;
}

bool GraphExecutorPy::InitParams(const py::dict &init_params, const std::string &phase) const {
  MS_LOG(INFO) << "Init params for the GE backend, phase = " << phase;
  if (info_.count(phase) == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "No phase in executor: " << GetPhasePrefix(phase);
  }
  DeviceContext *device_context = nullptr;
  try {
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id});
  } catch (const std::exception &) {
    return false;
  }
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
  return device_context->GetDeprecatedInterface()->RunInitGraph(info_.at(phase)->func_graph, init_params);
}

FuncGraphPtr GraphExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase) const {
  MS_LOG(INFO) << "Start building df graph, phase = " << phase;
  if (info_.count(phase) == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "No phase in executor: " << GetPhasePrefix(phase);
  }
  DeviceContext *device_context = nullptr;
  try {
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id});
  } catch (const std::exception &) {
    return nullptr;
  }
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
  return device_context->GetDeprecatedInterface()->BuildDFGraph(info_.at(phase)->func_graph, init_params);
}

void GraphExecutorPy::UpdataParamNodeDefaultInput(
  const std::string &phase, const std::unordered_map<std::string, tensor::TensorPtr> &params_value) {
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(DEBUG) << "UpdataParamNodeDefaultInput for func graph(" << func_graph->ToString() << ") phase(" << phase
                << ")!";
  auto &params = func_graph->parameters();
  for (const auto &param : params) {
    MS_EXCEPTION_IF_NULL(param);
    auto param_cast = param->cast_ptr<Parameter>();
    MS_EXCEPTION_IF_NULL(param_cast);
    auto iter = params_value.find(param_cast->name());
    if (iter != params_value.end()) {
      param_cast->set_default_param(iter->second);
    }
  }
}

py::dict GraphExecutorPy::GetParams(const std::string &phase) {
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  py::dict parameter_dict;
  std::vector<AnfNodePtr> graph_params = func_graph->parameters();
  for (auto &param : graph_params) {
    MS_EXCEPTION_IF_NULL(param);
    auto param_ptr = std::static_pointer_cast<Parameter>(param);
    std::string name = param_ptr->name();
    auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param());
    if (tensor != nullptr) {
      parameter_dict[py::str(name)] = *tensor;
    }
  }
  return parameter_dict;
}

py::bytes GraphExecutorPy::GetRandomStatus(const std::string &phase) const {
  auto iter = info_.find(phase);
  if (iter == info_.end()) {
    MS_LOG(ERROR) << "Phase " << phase << " has not been compiled yet.";
    return "";
  }
  MS_EXCEPTION_IF_NULL(iter->second);
  MS_EXCEPTION_IF_NULL(iter->second->resource);
  auto &resource = iter->second->resource;
  auto backend = resource->GetBackend();
  const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
  MS_EXCEPTION_IF_NULL(mindrt_backend);
  auto actor_info = resource->GetResult(kActorInfo).cast<compile::ActorInfo>();
  auto random_status = mindrt_backend->GetRandomStatus(actor_info);
  return py::bytes(random_status.c_str(), random_status.size());
}

void GraphExecutorPy::PyExePath(const py::object &py_exe_path) const {
  if (!py::isinstance<py::str>(py_exe_path)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Failed, py_exe_path input is not a str";
  }
  auto py_exe_path_s = py::cast<std::string>(py_exe_path);
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s);
}

void GraphExecutorPy::KernelBuildServerDir(const py::object &kernel_build_server_dir) const {
  if (!py::isinstance<py::str>(kernel_build_server_dir)) {
    MS_LOG(INTERNAL_EXCEPTION) << "Failed, kernel_build_server_dir input is not a str";
  }
  auto kernel_build_server_dir_s = py::cast<std::string>(kernel_build_server_dir);
  auto ms_context = MsContext::GetInstance();
  ms_context->set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, kernel_build_server_dir_s);
}

bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
                     const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                     const std::vector<int64_t> &input_indexes, const std::string &, bool need_run) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string name = ms_context->backend_policy();
#ifdef WITH_BACKEND
  if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
    if (!device_context->GetDeprecatedInterface()->IsTsdOpened(ms_context)) {
      InitPipeline();
    }
  }
#endif

  if (name == kMsConvert || name == kMsVm || name == "ge") {
#ifdef WITH_BACKEND
    if (iter_num == -1) {
      iter_num = INT32_MAX;
    }
    bool status = InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
    return status;
#endif
  }
  return name == "ge";
}

bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
                       const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
                       const std::vector<int64_t> &input_indexes, bool need_run) {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->cache_enable() &&
      !ps::PSContext::instance()->is_worker()) {
    return true;
  }
#endif
  MS_LOG(INFO) << "Start InitDataSet Entry";
  mindspore::python_adapter::set_python_env_flag(true);
  ShapeVector int_input_indexes;
  (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
                       [](int64_t item) { return static_cast<int64_t>(item); });
  std::vector<ShapeVector> int_shapes;
  (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes),
                       [](const std::vector<int64_t> &item) {
                         ShapeVector vector_item;
                         (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item),
                                              [](int64_t inner_item) { return static_cast<int64_t>(inner_item); });
                         return vector_item;
                       });
  auto p_init = std::make_shared<Primitive>("InitDataSetQueue");
  p_init->set_attr("queue_name", MakeValue(queue_name));
  p_init->set_attr("size", MakeValue(static_cast<int64_t>(size)));
  p_init->set_attr("batch_size", MakeValue(static_cast<int64_t>(batch_size)));
  p_init->set_attr("types", MakeValue(types));
  p_init->set_attr("shapes", MakeValue(int_shapes));
  p_init->set_attr("input_indexes", MakeValue(int_input_indexes));

  const std::vector<std::string> empty_str_list;
  p_init->set_attr("input_names", MakeValue(empty_str_list));
  p_init->set_attr("output_names", MakeValue(empty_str_list));

  FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
  auto app_init = std::make_shared<CNode>(AnfNodeWeakPtrList({NewValueNode(p_init)}), func_graph);
  func_graph->set_output(app_init);
  auto manager = MakeManager();
  manager->AddFuncGraph(func_graph);

  // AbstractNone indicates there is no output for this apply node.
  auto abstract_none = std::make_shared<abstract::AbstractNone>();
  app_init->set_abstract(abstract_none);
  // Reset the iteration number before compiling the graph.
  ConfigManager::GetInstance().ResetIterNum();
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::ResetRecorder();
#endif

  compile::SetMindRTEnable();
  auto backend = compile::CreateBackend();
  MS_EXCEPTION_IF_NULL(backend);
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  // Compile and run the dataset graph with mindRT.
  if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
#if defined(__linux__) && defined(WITH_BACKEND)
    if (ps::PSContext::instance()->is_worker() && ps::PSContext::instance()->cache_enable()) {
      distributed::DataQueueManager::GetInstance().CreateDataQueue(queue_name, size, 128);
    }
#endif

    const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
    MS_EXCEPTION_IF_NULL(mindrt_backend);
    SetRunMode(func_graph, mindrt_backend.get());
    auto &actor_info = mindrt_backend->CompileGraphs(func_graph);
    VectorRef args;
    if (need_run) {
      VectorRef outputs;
      mindrt_backend->RunGraph(actor_info, args, &outputs);
    }
    ConfigManager::GetInstance().set_iter_num(queue_name, size);
    return true;
  }

  auto convert_fn = backend->convert_fn();
  MS_EXCEPTION_IF_NULL(convert_fn);
  // Convert the CNode list to a LinConvertResult.
  auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false);
  auto runner = convert_fn(segment, "");
  ConfigManager::GetInstance().set_iter_num(queue_name, size);

  if (!(*runner.run)) {
    // The run function is empty.
    MS_LOG(EXCEPTION) << "Backend " << backend->name() << " does not support the tdt dataset.";
  }

  // Launch the init dataset runner without inputs and outputs.
  VectorRef args;
  auto fn = runner.run;
  if (need_run) {
    (void)(*fn)(args);
  }
  MS_LOG(DEBUG) << "InitDataSetVm End.";
  return true;
}

std::string GetJitLevel() {
  const auto &jit_config = PhaseManager::GetInstance().jit_config();
  auto iter = jit_config.find("jit_level");
  if (iter != jit_config.end()) {
    return iter->second;
  }
  return "";
}

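// Sketch of the lookup above: if the frontend registered a jit_config such as {"jit_level": "O1"} (the value is
// illustrative), GetJitLevel() returns "O1"; otherwise it returns an empty string and callers fall back to their
// defaults.
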
void ResetOpId() { mindspore::id_generator::reset_id(); }
void ResetOpIdWithOffset() { mindspore::id_generator::reset_id_with_offset(); }

void InitHccl() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
#ifdef WITH_BACKEND
  auto backend = ms_context->backend_policy();
  if (backend == "ge") {
    if (!mindspore::distributed::Initialize()) {
      MS_LOG(EXCEPTION) << "InitHccl failed.";
    }
    InitPipeline();
    return;
  }
#endif
  mindspore::python_adapter::set_python_env_flag(true);
  std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (ms_context->backend_policy() == "ms" && device_name == kAscendDevice) {
    if (!mindspore::distributed::Initialize()) {
      MS_LOG(EXCEPTION) << "InitHccl failed.";
    }
  }
}

void FinalizeHccl() {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
#ifdef WITH_BACKEND
  auto backend = ms_context->backend_policy();
  if (backend == "ge") {
    FinalizeBackend();
    return;
  }
#endif
  session::ExecutorManager::Instance().Clear();
  device::KernelRuntimeManager::Instance().ClearRuntimeResource();
  device::DeviceContextManager::GetInstance().ClearDeviceContexts();
  device::DeviceContextManager::GetInstance().UnloadPlugin();
}

uint32_t GetHcclRankId() {
  uint32_t rank_id = 0;
  bool ret = CommManager::GetInstance().GetRankID("", &rank_id);
  if (!ret) {
    MS_LOG(ERROR) << "Get rank id failed, return rank id " << rank_id << " as default.";
  }
  return rank_id;
}

uint32_t GetHcclRankSize() {
  uint32_t rank_size = 0;
  bool ret = CommManager::GetInstance().GetRankSize("", &rank_size);
  if (!ret) {
    MS_LOG(ERROR) << "Get rank size failed, return rank size " << rank_size << " as default.";
  }
  return rank_size;
}

void GraphExecutorPy::ExportGraph(const std::string &file_name, const std::string &phase, const py::object encrypt,
                                  char *key) {
  DeviceContext *device_context = nullptr;
  try {
    auto ms_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(ms_context);
    auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id});
  } catch (const std::exception &) {
    MS_EXCEPTION(ValueError) << "Exporting a file in 'AIR' format is only supported with the Ascend backend.";
  }
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
  FuncGraphPtr func_graph = info_[phase]->func_graph;
  MS_EXCEPTION_IF_NULL(func_graph);
  device_context->GetDeprecatedInterface()->ExportDFGraph(file_name, func_graph->ToString(), encrypt, key);
}

FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
                        const std::string &dec_mode, const py::object decrypt, const bool obfuscated) {
  if (obfuscated) {
    MS_LOG(DEBUG) << "[LoadMindIR] Set customized function.";
    (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
    (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
  }
  FuncGraphPtr func_graph = nullptr;
  if (dec_mode == "Customized") {
    py::bytes key_bytes(dec_key);
    py::bytes model_stream = decrypt(file_name, key_bytes);
    std::string model_string(model_stream);

    MindIRLoader mindir_loader;
    func_graph = mindir_loader.LoadMindIR(model_string.c_str(), model_string.size());
  } else {
    MindIRLoader mindir_loader(false, reinterpret_cast<const unsigned char *>(dec_key), key_len, dec_mode, false);
    func_graph = mindir_loader.LoadMindIR(file_name);
  }
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory)) {
    DumpIR("load.ir", func_graph);
  }
#endif
  return func_graph;
}

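// Hedged usage sketch (the dec_mode string below is an assumption; it must match a mode supported by
// utils/crypto):
//   auto fg = LoadMindIR("net.mindir", key_bytes, key_len, "AES-GCM", py::none(), false);
// With dec_mode == "Customized", the Python `decrypt` callable is invoked instead to produce the model bytes.
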
FuncGraphPtr SplitMindIR(const std::string &file_name) {
  MS_LOG(INFO) << "Start splitting MindIR";
  FuncGraphPtr func_graph = nullptr;
  MindIRLoader mindir_loader;
  func_graph = mindir_loader.LoadMindIR(file_name);
  if (func_graph == nullptr) {
    MS_LOG(ERROR) << "Failed to load the MindIR file. Please check the model file.";
    return nullptr;
  }
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  if (context->CanDump(kIntroductory)) {
    DumpIR("load.ir", func_graph);
  }
#endif
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  auto parallel_context = parallel::ParallelContext::GetInstance();
  parallel_context->Reset();
  parallel_context->set_parallel_mode(parallel::kAutoParallel);
  parallel_context->set_strategy_search_mode(parallel::kRecursiveProgramming);
  parallel_context->set_direct_split(true);
  parallel_context->set_full_batch(true);
  parallel_context->set_group_ckpt_save_file("group_info");

  FuncGraphManagerPtr func_graph_manager = func_graph->manager();
  if (func_graph_manager == nullptr) {
    MS_LOG(INFO) << "func_graph_manager is null, creating a new manager";
    std::vector<FuncGraphPtr> graphs{func_graph};
    func_graph_manager = std::make_shared<FuncGraphManager>(graphs);
    func_graph_manager->AddFuncGraph(func_graph);
  }
  pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
  resource->set_manager(func_graph_manager);

  // Get the parameter items and add their values to args_abs.
  auto params = func_graph->parameters();
  auto inputs = func_graph->get_inputs();
  for (std::size_t i = 0; i < inputs.size(); i++) {
    auto input = inputs[i]->abstract();
    (void)parallel::ExtendInputArgsAbstractShape(input, i);
  }
  parallel::StepAutoParallel(func_graph, nullptr);
  parallel::StepParallel(func_graph, nullptr);
  parallel::StepAllreduceFusion(func_graph, nullptr);
  resource->set_func_graph(func_graph);
  resource->set_manager(func_graph->manager());
  opt::irpass::OptimizeIRPassLib irpass;
  opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
  opt::OptPassConfig virtual_output = opt::OptPassConfig({irpass.virtual_output_eliminate_});

  opt::OptPassGroupMap map_parallel_eliminate(
    {{"virtual_dataset", virtual_dataset}, {"virtual_output", virtual_output}});

  auto split_pass_opts = opt::Optimizer::MakeOptimizer("map_parallel_eliminate", resource, map_parallel_eliminate);
  ProfileExecute(MsProfile::GetProfile()->Step("split_pass_opts"),
                 [&split_pass_opts, &func_graph]() { func_graph = split_pass_opts->step(func_graph, true); });

  AbstractBasePtrList args_abs_list;
  (void)std::transform(params.begin(), params.end(), std::back_inserter(args_abs_list),
                       [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
  func_graph = pipeline::Renormalize(resource, func_graph, args_abs_list);

  resource->set_args_abs(args_abs_list);

  MindIRExporter mindir_exporter;
  mindir_exporter.ExportProto(func_graph, "split_net", nullptr);

  parallel::HandleGroupInfo();

  return func_graph;
}

SplitDynamicMindIR(const std::string & file_name,size_t device_num,size_t rank_id,bool sapp)2349 FuncGraphPtr SplitDynamicMindIR(const std::string &file_name, size_t device_num, size_t rank_id, bool sapp) {
2350   MS_LOG(INFO) << "Start split dynamic mindir for transformer network";
2351   FuncGraphPtr func_graph = nullptr;
2352   MindIRLoader mindir_loader;
2353   func_graph = mindir_loader.LoadMindIR(file_name);
2354   if (func_graph == nullptr) {
2355     MS_LOG(ERROR) << "Load MindIR file failed. Please check model file.";
2356     return nullptr;
2357   }
2358 #ifdef ENABLE_DUMP_IR
2359   auto context = MsContext::GetInstance();
2360   MS_EXCEPTION_IF_NULL(context);
2361   if (context->CanDump(kIntroductory)) {
2362     DumpIR("load.ir", func_graph);
2363   }
2364 #endif
2365   auto ms_context = MsContext::GetInstance();
2366   MS_EXCEPTION_IF_NULL(ms_context);
2367   auto parallel_context = parallel::ParallelContext::GetInstance();
2368   parallel_context->Reset();
2369   parallel_context->set_parallel_mode(parallel::kAutoParallel);
2370   parallel_context->set_strategy_search_mode(parallel::kRecursiveProgramming);
2371   parallel_context->set_direct_split(true);
2372   parallel_context->set_full_batch(true);
2373   parallel_context->set_group_ckpt_save_file("group_info");
2374 
2375   for (size_t rank_id_iter = 0; rank_id_iter < device_num; rank_id_iter++) {
2376     auto tmp_func_graph = mindspore::BasicClone(func_graph);
2377     FuncGraphManagerPtr func_graph_manager = tmp_func_graph->manager();
2378 
2379     if (func_graph_manager == nullptr) {
2380       MS_LOG(INFO) << "func_graph_manager is null";
2381       std::vector<FuncGraphPtr> graphs{tmp_func_graph};
2382       func_graph_manager = std::make_shared<FuncGraphManager>(graphs);
2383       func_graph_manager->AddFuncGraph(tmp_func_graph);
2384     }
2385 
2386     auto inputs = tmp_func_graph->get_inputs();
2387     for (std::size_t i = 0; i < inputs.size(); i++) {
2388       auto input = inputs[i]->abstract();
2389       (void)parallel::ExtendInputArgsAbstractShape(input, i);
2390     }
2391 
2392     auto res = parallel::StepAssignedParallel(tmp_func_graph, func_graph_manager, device_num, rank_id_iter, sapp);
2393     if (!res) {
2394       MS_LOG(ERROR) << "StepAssignedParallel failed. Please check.";
2395       return nullptr;
2396     }
2397     pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
2398     resource->set_is_load(false);
2399     resource->set_manager(func_graph_manager);
2400     resource->set_func_graph(tmp_func_graph);
    // Collect the parameter items and add their abstracts to args_abs.
    auto params = tmp_func_graph->parameters();
    AbstractBasePtrList args_abs_list;
    (void)std::transform(params.begin(), params.end(), std::back_inserter(args_abs_list),
                         [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
    tmp_func_graph = pipeline::Renormalize(resource, tmp_func_graph, args_abs_list);

#ifdef ENABLE_DUMP_IR
    auto re_context = MsContext::GetInstance();
    MS_EXCEPTION_IF_NULL(re_context);
    if (re_context->CanDump(kIntroductory)) {
      string renormalize_net_name = "Renormalize_" + std::to_string(rank_id_iter) + ".ir";
      DumpIR(renormalize_net_name, tmp_func_graph);
    }
#endif

    parallel::HandleGroupInfo();
    string net_save_name = "split_net" + std::to_string(rank_id_iter);
    MindIRExporter mindir_exporter;
    res = mindir_exporter.ExportProto(tmp_func_graph, net_save_name, nullptr);
    if (!res) {
      MS_LOG(ERROR) << "Export MindIR file failed. Please check.";
      return nullptr;
    }
  }

  return func_graph;
}
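
// A minimal usage sketch for SplitDynamicMindIR (illustrative only; the file name and device
// count are assumptions). It loads the MindIR once, then clones and shards the graph for each
// rank, emitting one "split_net<rank>" MindIR per rank. Note that the rank_id argument is not
// consumed by the per-rank loop above, which iterates over all ranks itself:
//
//   FuncGraphPtr graph = SplitDynamicMindIR("transformer.mindir", /*device_num=*/8,
//                                           /*rank_id=*/0, /*sapp=*/true);
//   // On success: split_net0 ... split_net7 MindIR files exist in the working directory.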

FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int branch_control_input,
                                    char *dec_key, const size_t key_len, const std::string &dec_mode) {
  if (branch_control_input == 0) {
    (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
    MS_LOG(DEBUG) << "[DynamicObfuscateMindIR] set function names finished.";
  }
  mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, branch_control_input);
  MindIRLoader mindir_loader(false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode, false);
  FuncGraphPtr func_graph = mindir_loader.LoadMindIR(file_name);
  if (func_graph == nullptr) {
    MS_LOG(EXCEPTION) << "[DynamicObfuscateMindIR] load mindir failed, please check the mindir file.";
  }
  ModifyGraphs(func_graph);
  auto manager = func_graph->manager();
  if (manager == nullptr) {
    manager = MakeManager();
    manager->AddFuncGraph(func_graph, true);
  }
  InferFuncGraphLoaded(func_graph);
  mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(func_graph);
  if (obfuscated_graph == nullptr) {
    MS_LOG(ERROR) << "[DynamicObfuscateMindIR] obfuscate model failed.";
    return nullptr;
  }
  return obfuscated_graph;
}
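
// A minimal usage sketch for DynamicObfuscateMindIR (illustrative values; assumes an
// unencrypted input model, hence the null decryption key and the assumed "AES-GCM" mode):
//
//   FuncGraphPtr obf_graph = DynamicObfuscateMindIR("net.mindir", /*obf_ratio=*/0.2f,
//                                                   /*branch_control_input=*/0,
//                                                   /*dec_key=*/nullptr, /*key_len=*/0,
//                                                   /*dec_mode=*/"AES-GCM");
//   // branch_control_input == 0 takes the customized opaque-predicate path above.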

void CloseTsd(bool force) {
#ifdef WITH_BACKEND
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
    const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {kAscendDevice, context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
    (void)device_context->GetDeprecatedInterface()->CloseTsd(context_ptr, force);
  }
#endif
}

void InitPipeline() {
  // set python env flag
  RecordInitStatus();
  mindspore::python_adapter::set_python_env_flag(true);
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  CompileConfigManager::GetInstance().CollectCompileConfig();
#ifdef WITH_BACKEND
  auto backend = ms_context->backend_policy();
  auto device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  if (backend == "ge") {
    const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
    MS_EXCEPTION_IF_NULL(device_context);
    device_context->Initialize();
  }
  if (!common::UseDynamicCluster()) {
    if (device_name == kAscendDevice) {
      const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
      MS_EXCEPTION_IF_NULL(device_context);
      MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
      if (!device_context->GetDeprecatedInterface()->OpenTsd(ms_context)) {
        MS_LOG(EXCEPTION) << "Open tsd failed";
      }
    }
  }
#endif
}
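
// InitPipeline is expected to run once before any compilation: it records the init status,
// marks the Python environment as available, and collects the compile configuration. With a
// backend built in, it also initializes the GE device context and, outside dynamic-cluster
// mode, opens TSD on Ascend; CloseTsd()/FinalizeBackend() undo that last step.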

void FinalizeBackend() { CloseTsd(); }

void MemoryRecycle() {
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::ResetRecorder();
#endif
  ReclaimOptimizer();
  session::ExecutorManager::Instance().ClearDoneTasks();
  ad::g_k_prims.clear();
  ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  abstract::AnalysisContext::ClearContext();
  kArgsCache.clear();
  kCellArgsMap.clear();
  // Clean up static variables to prevent a crash: static variables are released after
  // the Python threads are released.
  parse::data_converter::ClearObjectCache();
  parse::Parser::CleanParserResource();
  trace::ClearTraceStack();
  pynative::PyNativeExecutor::GetInstance()->ClearRes();
  ConfigManager::GetInstance().ResetConfig();
  ScopeManager::GetInstance().ClearScope();
  FuncGraphLoopBreaker::Inst().CleanMetaFuncGraphs();
  FuncGraphLoopBreaker::Inst().BreakLoop();
}
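
// MemoryRecycle serves as the lighter-weight counterpart of the ClearResPart* teardown below:
// it drops compilation caches (k-prims, bprop optimizer, analysis results and contexts,
// parser, trace and scope state) so memory can be reclaimed between compilations without
// tearing down device runtimes or executors.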

void BindDeviceCtx() { device::DeviceContextManager::GetInstance().BindDeviceCtx(); }

void ClearResPart1() {
  pynative::PyNativeExecutor::GetInstance()->WorkerJoin();
  runtime::OpExecutor::GetInstance().WorkerJoin();
  // When the python process exits, the kernels on the device may not have finished executing.
  device::KernelRuntimeManager::Instance().WaitTaskFinishOnDevice();
  device::DeviceContextManager::GetInstance().WaitTaskFinishOnDevice();

  RecordExitStatus();
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::Snapshot();
  mindspore::RDR::ResetRecorder();
#endif
  runtime::GraphScheduler::GetInstance().Clear();
  runtime::ProfilerAnalyzer::GetInstance().Clear();

  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->backend_policy() != "ge") {
    // Clear runtime resources before destroying the HCCL comm.
    MS_LOG(INFO) << "Start clear kernel runtime...";
    device::KernelRuntimeManager::Instance().ClearRuntimeResource();
    MS_LOG(INFO) << "End clear kernel runtime.";
  }

  MS_LOG(INFO) << "Start Finalize StreamSynchronizer...";
  device::StreamSynchronizer::GetInstance()->Finalize();
  MS_LOG(INFO) << "End Finalize StreamSynchronizer.";

  PrimitivePy::ClearHookRes();
  ad::g_k_prims.clear();
  ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();

  abstract::ClearPrimEvaluatorMap();
  pipeline::GetMethodMap().clear();
  pipeline::GetAttrMap().clear();
  pipeline::GraphExecutorPy::ClearRes();
  pipeline::ReclaimOptimizer();
}

void ClearResPart2() {
  MS_LOG(INFO) << "Start clear PyNativeExecutor...";
  pynative::PyNativeExecutor::GetInstance()->ClearRes();
  MS_LOG(INFO) << "End clear PyNativeExecutor.";

#ifdef WITH_BACKEND
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->backend_policy() == "ge") {
    auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    DeviceContext *device_context =
      device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id});
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
    device_context->GetDeprecatedInterface()->ClearGraphWrapper();
    device_context->GetDeprecatedInterface()->ClearOpAdapterMap();
    // Unregister the external allocator before clearing the stream and GraphRunner.
    device_context->GetDeprecatedInterface()->UnregisterExternalAllocator();
    // For GE, clear runtime resources after clearing the graph.
    MS_LOG(INFO) << "Start clear kernel runtime...";
    device::KernelRuntimeManager::Instance().ClearRuntimeResource();
    MS_LOG(INFO) << "End clear kernel runtime.";
  } else {
    MS_LOG(INFO) << "Start clear ConfigManager...";
    ConfigManager::GetInstance().ResetIterNum();
    MS_LOG(INFO) << "End clear ConfigManager.";
  }
#else
  MS_LOG(INFO) << "Start clear ConfigManager...";
  ConfigManager::GetInstance().ResetIterNum();
  MS_LOG(INFO) << "End clear ConfigManager.";
#endif

  session::ExecutorManager::Instance().Clear();
  // For GE, HcclCommDestroy should happen after RemoveGraph in ClearGraphWrapper.
  (void)distributed::collective::CollectiveManager::instance()->Finalize();

  MS_LOG(INFO) << "Start clear device context...";
  device::DeviceContextManager::GetInstance().ClearDeviceContexts();
  MS_LOG(INFO) << "End clear device context.";

  MS_LOG(INFO) << "Start clear AnalysisResultCacheMgr...";
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  MS_LOG(INFO) << "End clear AnalysisResultCacheMgr.";

  MS_LOG(INFO) << "Start clear AnalysisContext...";
  abstract::AnalysisContext::ClearContext();
  MS_LOG(INFO) << "End clear AnalysisContext.";

  MS_LOG(INFO) << "Start clear AnalysisSchedule...";
  abstract::AnalysisSchedule::GetInstance().Stop();
  MS_LOG(INFO) << "End clear AnalysisSchedule.";
#ifdef ENABLE_DEBUGGER
  auto debugger = Debugger::GetInstance();
  MS_EXCEPTION_IF_NULL(debugger);
  debugger->Reset();
#endif
  kArgsCache.clear();
  kCellArgsMap.clear();
}

void ClearResPart3() {
  // Clean up static variables to prevent a crash: static variables are released after
  // the Python threads are released.
  MS_LOG(INFO) << "Start clear ClearObjectCache...";
  parse::data_converter::ClearObjectCache();
  MS_LOG(INFO) << "End clear ClearObjectCache.";

  MS_LOG(INFO) << "Start clear Parser...";
  parse::Parser::CleanParserResource();
  MS_LOG(INFO) << "End clear Parser.";

  MS_LOG(INFO) << "Start ClearTraceStack...";
  trace::ClearTraceStack();
  MS_LOG(INFO) << "End ClearTraceStack.";

  MS_LOG(INFO) << "Start clear InterpretNodeRecorder...";
  InterpretNodeRecorder::GetInstance().Clear();
  MS_LOG(INFO) << "End clear InterpretNodeRecorder.";

  MS_LOG(INFO) << "Start clear parallel::entire_costgraph...";
  parallel::entire_costgraph.reset();
  MS_LOG(INFO) << "End clear parallel::entire_costgraph.";

  MS_LOG(INFO) << "Start clear ProtobufLibrary...";
  google::protobuf::ShutdownProtobufLibrary();
  MS_LOG(INFO) << "End clear ProtobufLibrary.";
  // ResetPythonScope after all py::object is freed.
  MS_LOG(INFO) << "Start clear python_adapter...";
  python_adapter::ResetPythonScope();
  MS_LOG(INFO) << "End clear python_adapter.";
}

void ClearSingleton() {
  MS_LOG(INFO) << "Start clear singleton...";
  profiler::Profiler::Clear();
#ifdef ENABLE_AKG
  kernel::GraphKernelBuildManager::Instance().Clear();
#endif
  somas::SomasManager::Instance().Clear();
  GraphKernelInfoManager::Instance().Clear();
  device::DataQueueMgr::GetInstance().Clear();
  session::SessionFactory::Get().Clear();
  device::KernelRuntimeManager::Instance().Clear();
  OpPrimPyRegister::GetInstance().Clear();
#ifndef ENABLE_SECURITY
  DumpJsonParser::Finalize();
  AclDumpJsonWriter::Finalize();
#endif
  CommManager::Clear();
  expander::ClearAllCache();
  MS_LOG(INFO) << "End clear singleton.";
}

void ClearResAtexit() {
  MS_LOG(INFO) << "Pipeline clear all resource";
  try {
    MsException::Instance().CheckException();
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Check exception before process exit: " << e.what();
  }
  ClearResPart1();
  ClearResPart2();

  mindspore::trans::FormatHelper::GetInstance().Clear();
  ClearResPart3();
  ClearSingleton();
  MS_LOG(INFO) << "Start unload dynamic lib...";
  device::DeviceContextManager::GetInstance().UnloadPlugin();
  MS_LOG(INFO) << "End unload dynamic lib.";
}
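
// The teardown order in ClearResAtexit is deliberate: ClearResPart1 joins the PyNative and op
// executor workers and waits for in-flight device tasks before anything is freed;
// ClearResPart2 releases executors, collectives and device contexts; ClearResPart3 clears
// parser/protobuf/Python-scope statics; only then are the remaining singletons cleared and
// the backend plugins unloaded. Reordering these steps would risk use-after-free between the
// runtime and the Python layer.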

py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_len, const std::string &enc_mode) {
  size_t encrypt_len;
  auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
                                         reinterpret_cast<Byte *>(key), key_len, enc_mode);
  if (encrypt_data == nullptr) {
    MS_EXCEPTION(ValueError) << "Encrypt failed";
  }
  auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
  return py_encrypt_data;
}

py::bytes PyDecrypt(const std::string &encrypt_data_path, char *key, size_t key_len, const std::string &dec_mode) {
  size_t decrypt_len;
  auto decrypt_data =
    mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
  if (decrypt_data == nullptr) {
    MS_LOG(ERROR) << "Decrypt failed";
    return py::none();
  }
  auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
  return py_decrypt_data;
}

py::bytes PyDecryptData(char *model_data, size_t data_size, char *key, size_t key_len, const std::string &dec_mode) {
  size_t decrypt_len;
  auto decrypt_data = mindspore::Decrypt(&decrypt_len, reinterpret_cast<Byte *>(model_data), data_size,
                                         reinterpret_cast<Byte *>(key), key_len, dec_mode);
  if (decrypt_data == nullptr) {
    MS_LOG(ERROR) << "Decrypt failed";
    return py::none();
  }
  auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
  return py_decrypt_data;
}
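
// A round-trip sketch for the crypto bindings above (illustrative buffers and key, and an
// assumed "AES-GCM" mode; these helpers are exposed to Python via pybind11, so the py::bytes
// results map to Python bytes):
//
//   char key[16] = {0};                       // assumption: 128-bit key material
//   py::bytes enc = PyEncrypt(model_buf, model_len, key, sizeof(key), "AES-GCM");
//   // ... persist `enc`, then later recover the plain model bytes:
//   py::bytes dec = PyDecryptData(enc_buf, enc_len, key, sizeof(key), "AES-GCM");
//   // `dec` should equal the original model_buf contents.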

bool PyIsCipherFile(const std::string &file_path) { return mindspore::IsCipherFile(file_path); }

void FinalizeCluster() {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (distributed::cluster::ClusterContext::instance()->initialized()) {
    if (!distributed::cluster_exit_with_exception()) {
      MS_LOG(INFO) << "Start finalize the cluster instance.";
      // Finalize MindSpore cluster only when this process exits without any exception.
      (void)distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX);
      MS_LOG(INFO) << "End finalize the cluster instance.";
    }
  }
#endif
}

void SwapCache(const tensor::TensorPtr &host, const tensor::TensorPtr &device, const tensor::TensorPtr &block_mapping,
               const bool &is_device_to_host) {
  auto block_mapping_shape = block_mapping->shape();
  if (block_mapping_shape.size() != 2) {
    MS_LOG_EXCEPTION << "The shape size of Cache input mapping tensor should be 2, but got: "
                     << block_mapping_shape.size();
  }
  if (block_mapping_shape[1] != 2) {
    MS_LOG_EXCEPTION << "The second dim of CacheKernel input mapping tensor should be 2, but got: "
                     << block_mapping_shape[1];
  }

  auto in_shape = device->shape();
  auto type_byte = GetTypeByte(TypeIdToType(host->data_type()));
  size_t block_size_in_bytes = LongToSize(
    std::accumulate(in_shape.begin() + 1, in_shape.end(), SizeToLong(type_byte), std::multiplies<int64_t>()));

  uint8_t *host_ptr = reinterpret_cast<uint8_t *>(host->data_c());
  MS_EXCEPTION_IF_NULL(host_ptr);
  auto device_addr = std::dynamic_pointer_cast<device::DeviceAddress>(device->device_address());
  MS_EXCEPTION_IF_NULL(device_addr);
  uint8_t *device_ptr = reinterpret_cast<uint8_t *>(const_cast<void *>(device_addr->GetPtr()));
  MS_EXCEPTION_IF_NULL(device_ptr);

  auto block_mapping_data = reinterpret_cast<int64_t *>(block_mapping->data_c());
  for (int64_t i = 0; i < block_mapping_shape[0]; i++) {
    int64_t src_block_num = block_mapping_data[2 * i];
    int64_t dst_block_num = block_mapping_data[2 * i + 1];
    size_t src_block_offset = LongToSize(src_block_num) * block_size_in_bytes;
    size_t dst_block_offset = LongToSize(dst_block_num) * block_size_in_bytes;

    if (is_device_to_host) {
      device_addr->CopyDeviceToHost(host_ptr + dst_block_offset, device_ptr + src_block_offset, block_size_in_bytes);
    } else {
      device_addr->CopyHostToDevice(device_ptr + dst_block_offset, host_ptr + src_block_offset, block_size_in_bytes);
    }
  }
}
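
// SwapCache semantics as a worked example (hypothetical shapes): with a device tensor of
// shape [num_blocks, 16, 128] and float16 data, block_size_in_bytes = 16 * 128 * 2 = 4096.
// A block_mapping tensor of shape [2, 2] holding {{3, 0}, {7, 1}} with
// is_device_to_host == true copies device block 3 into host block 0 and device block 7 into
// host block 1, one contiguous block per copy.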
}  // namespace pipeline
}  // namespace mindspore