/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "pipeline/jit/ps/pipeline.h"

#include <memory>
#include <map>
#include <cstdlib>
#include <algorithm>
#include <iomanip>
#include <unordered_map>
#include <functional>
#include "mindspore/core/ops/framework_ops.h"
#include "pybind_api/pybind_patch.h"
#include "pybind11/pybind11.h"
#include "ir/param_info.h"
#include "pipeline/jit/ps/action.h"
#include "pipeline/jit/ps/pass.h"
#include "pipeline/jit/ps/parse/data_converter.h"
#include "pipeline/jit/ps/static_analysis/async_eval_result.h"
#include "pipeline/jit/ps/compile_cache_manager.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/ad/dfunctor.h"
#include "frontend/optimizer/ad/prim_bprop_optimizer.h"
#include "include/common/utils/parallel_context.h"
#include "frontend/parallel/step_parallel_utils.h"
#include "frontend/parallel/parameter_manager.h"
#include "frontend/parallel/graph_util/get_parallel_info.h"
#include "frontend/parallel/graph_util/flops_collection.h"
#include "frontend/parallel/auto_parallel/graph_costmodel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/device_manager.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/parallel/pass/handle_group_info.h"
#include "frontend/parallel/step_assigned_parallel.h"
#include "frontend/parallel/dynamic_shape/dynamic_shape.h"
#include "frontend/expander/utils.h"
#include "include/common/utils/config_manager.h"
#include "include/common/utils/convert_utils.h"
#include "include/common/utils/convert_utils_py.h"
#include "include/common/utils/python_utils.h"
#include "utils/log_adapter.h"
#include "utils/ms_context.h"
#include "utils/shape_utils.h"
#include "utils/info.h"
#include "utils/crypto.h"
#include "utils/phase.h"
#include "utils/compile_config.h"
#include "include/common/utils/comm_manager.h"
#include "include/common/utils/stub_tensor.h"
#include "utils/interpret_node_recorder.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/debug/dump_proto.h"
#include "pipeline/jit/ps/fallback.h"
#include "pipeline/jit/ps/debug/trace.h"
#include "pipeline/jit/ps/event_message_print.h"
#include "include/common/debug/draw.h"
#include "include/common/debug/common.h"
#include "load_mindir/load_model.h"
#include "backend/graph_compiler/segment_runner.h"
#include "backend/common/session/executor_manager.h"
#include "backend/common/session/session_factory.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/device/stream_synchronizer.h"
#include "include/common/fallback.h"
#include "include/common/profiler.h"
#include "include/backend/distributed/collective/collective_manager.h"
#include "include/backend/distributed/recovery/recovery_context.h"
#include "include/common/utils/dynamic_obfuscation/dynamic_obfuscation.h"
#include "include/common/utils/dynamic_obfuscation/registry_opaque_predicate.h"
#include "mindspore/ccsrc/plugin/device/cpu/kernel/pyexecute/py_execute_cpu_kernel.h"
#include "include/backend/distributed/init.h"
#include "include/backend/debug/profiler/profiling.h"
#include "kernel/graph_kernel/graph_kernel_builder_manager.h"
#include "kernel/graph_kernel_info.h"
#include "include/backend/data_queue/data_queue_mgr.h"
#include "mindspore/core/symbolic_shape/symbol_info.h"
#include "include/common/symbol_engine/symbol_engine_impl.h"
#include "pipeline/jit/ps/load_mindir.h"
#include "load_mindir/infer_mindir.h"

#ifndef ENABLE_SECURITY
#include "include/backend/debug/data_dump/dump_json_parser.h"
#include "include/backend/debug/data_dump/acl_dump_json_writer.h"
#include "abstract/abstract_value.h"
#endif
#if defined(__linux__) && defined(WITH_BACKEND)
#include "include/backend/distributed/ps/constants.h"
#include "include/backend/distributed/ps/util.h"
#include "include/backend/distributed/ps/ps_cache/ps_data_prefetch.h"
#include "include/backend/distributed/cluster/cluster_context.h"
#include "runtime/graph_scheduler/embedding_cache_scheduler.h"
#include "include/backend/distributed/ps/ps_context.h"
#include "include/backend/distributed/embedding_cache/data_queue_manager.h"
#endif
#ifdef ENABLE_DUMP_IR
#include "debug/rdr/graph_recorder.h"
#include "include/common/debug/rdr/recorder_manager.h"
#include "ir/cell.h"
#endif

#include "pybind_api/ir/log_adapter_py.h"  // Only include once in the whole project.
#include "pybind_api/ir/py_execute_py.h"   // Only include once in the whole project.
#include "include/common/utils/compile_cache_context.h"

namespace mindspore {
// namespace to support intermediate representation definition
namespace pipeline {
using Tensor = mindspore::tensor::Tensor;
using MetaTensor = mindspore::tensor::MetaTensor;
using MetaSparseTensor = mindspore::tensor::MetaSparseTensor;
using CSRTensor = mindspore::tensor::CSRTensor;
using COOTensor = mindspore::tensor::COOTensor;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTensorPtr;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
using DeviceTensor = mindspore::device::DeviceAddress;

const char IR_TYPE_ANF[] = "anf_ir";
const char IR_TYPE_ONNX[] = "onnx_ir";
const char IR_TYPE_MINDIR[] = "mind_ir";

GraphExecutorPyPtr GraphExecutorPy::executor_ = nullptr;
std::mutex GraphExecutorPy::instance_lock_;

std::unordered_map<abstract::AbstractBasePtrList, uint64_t, abstract::AbstractBasePtrListHasher,
                   abstract::AbstractBasePtrListEqual>
    kArgsCache;
std::unordered_map<PyObject *, abstract::AbstractBasePtrList> kCellArgsMap;
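// Note (added commentary, not in the original source): kArgsCache maps the abstract
// signature of a call's arguments to a compile key, so repeated calls with the same
// shapes/types reuse one compiled graph; kCellArgsMap remembers each cell's last
// signature so ClearArgCache() below can evict both entries when the cell is released.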

namespace {
#ifdef ENABLE_DUMP_IR
std::string GetBaseNameForIR(int64_t stage_idx, const std::string &action_name) {
  std::ostringstream oss;
  int spaces = 2;
  oss << std::setfill('0') << std::setw(spaces) << stage_idx << "_" << action_name;
  return oss.str();
}
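// Illustrative example (added commentary, not in the original source): this builds the
// zero-padded prefix used for per-stage IR dump files, e.g.
//   GetBaseNameForIR(3, "validate") -> "03_validate"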
#endif

bool CheckAllTensor(const ValueTuplePtr &value_tuple) {
  auto elements = value_tuple->value();
  for (auto element : elements) {
    MS_EXCEPTION_IF_NULL(element);
    if (!(element->isa<ValueTuple>() && CheckAllTensor(element->cast<ValueTuplePtr>())) &&
        !(element->isa<MetaTensor>())) {
      return false;
    }
  }
  return true;
}

bool Mutable(const py::object &obj, const ValuePtr &value) {
  // If a tensor has been set as a const arg, it should not be mutable.
  if (value->isa<MetaTensor>()) {
    constexpr char const_arg_attr[] = "const_arg";
    if (py::hasattr(obj, const_arg_attr) && py::cast<bool>(py::getattr(obj, const_arg_attr))) {
      return false;
    }
  }
  constexpr char mutable_attr[] = "__ms_mutable__";
  return py::hasattr(obj, mutable_attr) && py::cast<bool>(py::getattr(obj, mutable_attr));
}
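// Illustrative note (added commentary, not in the original source), assuming the
// Python-side mindspore.mutable() helper sets the __ms_mutable__ attribute:
//   x = mutable(Tensor(...))  ->  Mutable(obj, value) is true, so x is broadened below;
// a tensor whose const_arg attribute is True is always treated as a constant.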

bool CheckAndConvertToVariableLenSequence(const py::object &obj, AbstractBasePtr abs) {
  constexpr char variable_len_attr[] = "__ms_dynamic_len__";
  bool dynamic_len = (py::hasattr(obj, variable_len_attr) && py::cast<bool>(py::getattr(obj, variable_len_attr)));
  if (!dynamic_len) {
    return false;
  }
  if (!abs->isa<abstract::AbstractSequence>()) {
    MS_EXCEPTION(TypeError) << "For mutable, when dynamic_len is True, the first input should be"
                            << " list or tuple, but got: " << abs->ToString();
  }
  auto abs_seq = abs->cast<abstract::AbstractSequencePtr>();
  abs_seq->CheckAndConvertToDynamicLenSequence();
  return true;
}

bool TensorArgMutable(const py::object &obj, const ValuePtr &value) {
  if (!value->isa<MetaTensor>()) {
    return false;
  }
  constexpr char const_arg_attr[] = "const_arg";
  return !py::hasattr(obj, const_arg_attr) || !py::cast<bool>(py::getattr(obj, const_arg_attr));
}

bool EnableTupleBroaden(const ValuePtr &value, bool enable_tuple_broaden) {
  return enable_tuple_broaden && value->isa<ValueTuple>() && CheckAllTensor(value->cast<ValueTuplePtr>());
}

bool GradForScalar(const ValuePtr &value) {
  return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && value->isa<Scalar>();
}

AbstractBasePtr ArgsToAbstract(const py::object &arg, const ValuePtr &value, bool enable_tuple_broaden = false) {
  bool broaden = TensorArgMutable(arg, value) || Mutable(arg, value) || value->isa<MetaSparseTensor>() ||
                 EnableTupleBroaden(value, enable_tuple_broaden) || GradForScalar(value);
  auto ret = abstract::ToAbstract(value, nullptr, nullptr);
  if (broaden) {
    ret = AbstractBroaden(ret);
  }
  auto is_dynamic_len = CheckAndConvertToVariableLenSequence(arg, ret);
  if (fallback::EnableFallbackListDictInplace() && !broaden && !is_dynamic_len) {
    // Attach corresponding list python object for constant list input.
    fallback::AttachPyObjToAbs(ret, arg, false);
  }
  return ret;
}
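// Summary (added commentary, not in the original source): an argument is broadened to a
// shape/type-only abstract when it is a non-const tensor, explicitly mutable, a sparse
// tensor, a tuple of tensors with tuple broadening enabled, or a scalar with
// MS_CTX_GRAD_FOR_SCALAR set; otherwise its concrete value is kept and may be folded
// into the compiled graph as a constant.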

bool CheckArgValid(const py::handle &arg) {
  if (py::isinstance<py::list>(arg) || py::isinstance<py::tuple>(arg)) {
    auto vector_arg = py::cast<py::list>(arg);
    return std::all_of(vector_arg.begin(), vector_arg.end(), CheckArgValid);
  }

  if (py::isinstance<py::dict>(arg)) {
    auto dict_arg = py::cast<py::dict>(arg);
    return std::all_of(dict_arg.begin(), dict_arg.end(), [](const auto &pair) { return CheckArgValid(pair.second); });
  }

  if (py::isinstance<Tensor>(arg) || IsStubTensor(arg)) {
    auto tensor = IsStubTensor(arg) ? ConvertStubTensor(arg) : py::cast<TensorPtr>(arg);
    if (tensor->data_type() == kNumberTypeBool) {
      MS_LOG(INFO) << "It is not recommended to use a tensor of bool data type as network input, which may cause "
                   << "operator compilation failure. For more details, please refer to the FAQ at "
                   << "https://mindspore.cn/search?[AddN]%20input(kNumberTypeBool.";
    }
  }

  return IsStubTensor(arg) || py::isinstance<py::int_>(arg) || py::isinstance<py::float_>(arg) ||
         py::isinstance<py::none>(arg) || py::isinstance<Number>(arg) || py::isinstance<py::str>(arg) ||
         py::isinstance<Tensor>(arg) || py::isinstance<CSRTensor>(arg) || py::isinstance<COOTensor>(arg);
}

std::string GetCompileExceptionInfo() {
  std::ostringstream oss;
  trace::GetTraceStackInfo(oss);
  return oss.str();
}

void SetLoopCount(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
  auto func_graph = resource->func_graph();
  if (func_graph != nullptr && func_graph->manager() != nullptr) {
    auto manager = func_graph->manager();
    size_t graph_nums = manager->func_graphs().size();
    int64_t loop_size = ConfigManager::GetInstance().iter_num();
    const auto context_ptr = MsContext::GetInstance();
    bool enable_mind_rt = context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT);
    if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
      resource->set_vm_loop(!(context_ptr->get_param<bool>(MS_CTX_IS_MULTI_GRAPH_SINK) || enable_mind_rt), loop_size);
    } else if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kGPUDevice) {
      bool run_with_mind_rt = graph_nums == 1 || enable_mind_rt;
      resource->set_vm_loop(!run_with_mind_rt, loop_size);
    }
    MS_LOG(INFO) << "Change vm_loop_flag to " << resource->vm_loop_flag() << ", set loop_size to " << loop_size;
  }
}

std::map<string, string> GenerateJitConfigMap(const py::dict &jit_config) {
  std::map<string, string> ret{};
  for (auto jit_param = jit_config.begin(); jit_param != jit_config.end(); ++jit_param) {
    auto param_name = py::cast<std::string>(jit_param->first);
    auto param_value = py::cast<std::string>(jit_param->second);
    ret[param_name] = param_value;
  }
  return ret;
}

void RecordInitStatus() {
  static bool printed = false;
  if (!printed) {
    MS_LOG(INFO) << "Status record: system init.";
    printed = true;
  }
}

void RecordExitStatus() { MS_LOG(INFO) << "Status record: system exit."; }

std::string ToOrdinal(const size_t &i) {
  auto suffix = "th";
  if (i == kIndex1) {
    suffix = "st";
  } else if (i == kIndex2) {
    suffix = "nd";
  } else if (i == kIndex3) {
    suffix = "rd";
  }
  return std::to_string(i) + suffix;
}
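// Illustrative example (added commentary, not in the original source), assuming
// kIndex1..kIndex3 are the constants 1..3:
//   ToOrdinal(1) == "1st", ToOrdinal(2) == "2nd", ToOrdinal(4) == "4th"
// (used above to format argument positions in log messages).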

kernel::PyExecuteOutputUserDataPtr GetUserDataFromAddress(const py::object &res) {
  const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
  if (!allow_fallback_runtime) {
    return nullptr;
  }

  if (py::isinstance<tensor::Tensor>(res) || IsStubTensor(res)) {
    auto res_tensor = IsStubTensor(res) ? ConvertStubTensor(res) : res.cast<tensor::TensorPtr>();
    MS_EXCEPTION_IF_NULL(res_tensor);
    if (res_tensor->device_address() != nullptr) {
      auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(res_tensor->device_address());
      MS_LOG(DEBUG) << "res tensor_address:" << tensor_address;
      MS_EXCEPTION_IF_NULL(tensor_address);
      if (tensor_address->user_data() != nullptr) {
        return tensor_address->user_data()->get<kernel::PyExecuteOutputUserData>(kernel::PyExecuteOutputUserData::key);
      }
    }
  }
  return nullptr;
}

py::object BaseRefToPyDataWithUserData(const BaseRef &value, const AbstractBasePtr &abs);

template <typename T>
py::object GetVectorRefPyDataWithAbstract(const VectorRef &value_list, const abstract::AbstractSequencePtr &seq_abs) {
  auto value_size = value_list.size();
  auto ret = T(value_size);

  const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
  size_t ref_idx = 0;
  for (size_t i = 0; i < seq_abs->size(); ++i) {
    auto elem_abs = seq_abs->elements()[i];
    if (elem_abs->isa<abstract::AbstractNone>() && !allow_fallback_runtime) {
      continue;
    }
    ret[ref_idx] = BaseRefToPyDataWithUserData(value_list[ref_idx], elem_abs);
    ref_idx++;
  }
  if (ref_idx != value_size) {
    MS_LOG(EXCEPTION) << "The size of elements (excluding None) should be equal to " << value_size << ", but got "
                      << ref_idx;
  }
  return ret;
}

py::object GetVectorRefPyData(const VectorRef &value_list, const AbstractBasePtr &abs) {
  if (abs == nullptr || abs->isa<abstract::AbstractCSRTensor>() || abs->isa<abstract::AbstractCOOTensor>() ||
      abs->isa<abstract::AbstractAny>()) {
    return BaseRefToPyData(value_list, abs);
  }
  // Need to consider AbstractAny with vector ref scene later.
  if (!abs->isa<abstract::AbstractSequence>()) {
    MS_LOG(EXCEPTION) << "Can not convert vector ref with abstract " << abs->ToString();
  }
  auto seq_abs = abs->cast<abstract::AbstractSequencePtr>();
  if (seq_abs->dynamic_len()) {
    return BaseRefToPyData(value_list, abs);
  }
  if (seq_abs->isa<abstract::AbstractTuple>()) {
    return GetVectorRefPyDataWithAbstract<py::tuple>(value_list, seq_abs);
  }
  return GetVectorRefPyDataWithAbstract<py::list>(value_list, seq_abs);
}
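// Note (added commentary, not in the original source): when the fallback runtime is
// disabled, None outputs are dropped from the runtime results, so the loop above walks
// the abstract sequence, skips AbstractNone elements, and indexes the VectorRef with
// its own counter; the final size check guards against a residual mismatch.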

py::object BaseRefToPyDataWithUserData(const BaseRef &value, const AbstractBasePtr &abs) {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kGraphExecutorPy, runtime::ProfilerEvent::kOutputProcess,
                                     "BaseRefToPyData");
  const auto allow_fallback_runtime = (fallback::GetJitSyntaxLevel() >= kCompatible);
  if (!allow_fallback_runtime) {
    return BaseRefToPyData(value, abs);
  }
  if (utils::isa<ValuePtr>(value)) {
    // Do not use abs as input to BaseRefToPyData, since the res needs to be a tensor to get user data.
    auto res = BaseRefToPyData(value);
    MS_LOG(DEBUG) << "res: " << py::str(res);
    const auto user_data = GetUserDataFromAddress(res);
    if (user_data != nullptr) {
      return user_data->obj;
    } else {
      MS_LOG(DEBUG) << "user data is empty";
    }
  } else if (utils::isa<VectorRef>(value)) {
    auto vec_ref = utils::cast<VectorRef>(value);
    return GetVectorRefPyData(vec_ref, abs);
  }
  return BaseRefToPyData(value, abs);
}

void AddManager(const FuncGraphManagerPtr &manager, const ValuePtr &value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<FuncGraph>()) {
    auto fg = value->cast<FuncGraphPtr>();
    manager->AddFuncGraph(fg);
  }
  if (value->isa<ValueSequence>()) {
    auto value_sequence = value->cast<ValueSequencePtr>();
    for (const auto &elem : value_sequence->value()) {
      AddManager(manager, elem);
    }
  }
  if (value->isa<ValueDictionary>()) {
    for (const auto &elem : value->cast<ValueDictionaryPtr>()->value()) {
      AddManager(manager, elem.second);
    }
  }
}

void AddManagerForFuncGraphArgs(const ResourcePtr &resource, const ValuePtrList &arguments) {
  auto manager = resource->manager();
  MS_EXCEPTION_IF_NULL(manager);
  for (const auto &arg : arguments) {
    AddManager(manager, arg);
  }
}

void ResetId(const ResourcePtr &resource) {
  MS_EXCEPTION_IF_NULL(resource);
#ifdef ENABLE_DUMP_IR
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  auto need_dump = common::GetCompileConfig("DUMP_VALIDATE_BEFORE_RESET_ID");
  if (context->CanDump(kIntroductory) && need_dump == "1") {
    FuncGraphPtr graph = resource->func_graph();
    DumpIR("validate_before_reset_id.ir", graph, true, kWholeStack);
  }
#endif
  mindspore::id_generator::reset_id();
  const auto &all_nodes = TopoSort(resource->func_graph()->get_return(), SuccDeeperSimple);
  for (const auto &node : all_nodes) {
    if (node != nullptr && node->isa<CNode>()) {
      const auto &cnode = node->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(cnode);
      cnode->set_fullname_with_scope("");
    }
  }
}

void CheckShapeConsistency(const abstract::ShapePtr &compile_shape, const abstract::ShapePtr &args_shape,
                           const std::string &target_str, size_t index) {
  MS_EXCEPTION_IF_NULL(compile_shape);
  MS_EXCEPTION_IF_NULL(args_shape);
  if (*compile_shape == *args_shape) {
    return;
  }

  auto compile_shape_vec = compile_shape->shape();
  auto args_shape_vec = args_shape->shape();

  if (!IsDynamicRank(compile_shape_vec)) {
    if (!args_shape_vec.empty() && compile_shape_vec.size() != args_shape_vec.size()) {
      MS_EXCEPTION(ValueError) << "For " << target_str << " and tuple(list) in " << target_str << ", the dims of "
                               << index + 1 << "th input must be the same as expected, "
                               << "but got expected: " << compile_shape_vec.size()
                               << ", and input: " << args_shape_vec.size() << "!";
    }

    for (size_t i = 0; i < compile_shape_vec.size(); ++i) {
      if (compile_shape_vec[i] == abstract::Shape::kShapeDimAny || compile_shape_vec[i] == args_shape_vec[i]) {
        continue;
      }
      MS_EXCEPTION(ValueError) << "For " << target_str << " and tuple(list) in " << target_str << ", the shape of "
                               << index + 1 << "th input must be the same as expected, "
                               << "but got expected: " << compile_shape_vec[i] << ", and input: " << args_shape_vec[i]
                               << "!";
    }
  }
}
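// Illustrative example (added commentary, not in the original source), assuming
// abstract::Shape::kShapeDimAny marks a dynamic dimension: a graph compiled with shape
// (-1, 224) accepts a (32, 224) input, but rejects (32, 224, 3) (rank mismatch) and
// (32, 112) (dimension mismatch).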

inline void CheckSizeConsistency(const AbstractBasePtrList &compile_abstracts,
                                 const AbstractBasePtrList &args_abstracts, const std::string &target_str,
                                 bool dynamic_len = false) {
  if (!dynamic_len && compile_abstracts.size() != args_abstracts.size()) {
    MS_EXCEPTION(ValueError) << "For " << target_str << " and tuple(list) in " << target_str
                             << ", the length of input must be equal to expected one, but got expected: "
                             << compile_abstracts.size() << " and input: " << args_abstracts.size() << "!";
  }
  if (dynamic_len && compile_abstracts.empty()) {
    MS_LOG(INTERNAL_EXCEPTION) << "For " << target_str << ", the dynamic_len compile arguments should not be empty!";
  }
}

void CheckAbstractConsistency(const AbstractBasePtrList &compile_abstracts, const AbstractBasePtrList &args_abstracts,
                              const std::string &target_str, bool dynamic_len = false) {
  CheckSizeConsistency(compile_abstracts, args_abstracts, target_str, dynamic_len);
  for (size_t i = 0; i < args_abstracts.size(); ++i) {
    auto compile_abs = dynamic_len ? compile_abstracts[0] : compile_abstracts[i];
    auto args_abs = args_abstracts[i];
    auto is_compile_var = compile_abs->BuildValue()->ContainsValueAny();
    auto is_args_var = args_abs->BuildValue()->ContainsValueAny();
    if (is_compile_var != is_args_var) {
      MS_EXCEPTION(TypeError) << "For " << target_str << " or tuple(list) in " << target_str << ", the " << i + 1
                              << "th should be " << (is_compile_var ? "mutable" : "static") << " one, but got "
                              << (is_args_var ? "mutable" : "static") << "!";
    }

    if (is_compile_var) {
      if (compile_abs->isa<abstract::AbstractTensor>() && args_abs->isa<abstract::AbstractTensor>()) {
        auto compile_tensor = compile_abs->cast<abstract::AbstractTensorPtr>();
        auto args_tensor = args_abs->cast<abstract::AbstractTensorPtr>();

        // Check shape's consistency.
        auto compile_shape = compile_tensor->shape();
        auto args_shape = args_tensor->shape();
        CheckShapeConsistency(compile_shape, args_shape, target_str, i);

        auto compile_element = compile_tensor->element();
        auto args_element = args_tensor->element();
        if (!common::IsEqual(compile_element, args_element)) {
          MS_EXCEPTION(TypeError) << "For " << target_str << " or tuple(list) in " << target_str << ", the " << i + 1
                                  << "th type should be " << compile_tensor->BuildType()->ToString() << ", but got "
                                  << args_tensor->BuildType()->ToString() << "!";
        }
      } else if (compile_abs->isa<abstract::AbstractSequence>() && args_abs->isa<abstract::AbstractSequence>()) {
        auto compile_sequence = compile_abs->cast<abstract::AbstractSequencePtr>();
        auto args_sequence = args_abs->cast<abstract::AbstractSequencePtr>();
        CheckAbstractConsistency(compile_sequence->elements(), args_sequence->elements(), target_str,
                                 compile_sequence->dynamic_len());
      } else {
        if (!common::IsEqual(compile_abs, args_abs)) {
          MS_EXCEPTION(ValueError) << "For " << target_str << " or tuple(list) in " << target_str << ", the " << i + 1
                                   << "th should be " << compile_abs->ToString() << ", but got "
                                   << args_abs->ToString() << "!";
        }
      }
    } else if (compile_abs->isa<abstract::AbstractList>() && args_abs->isa<abstract::AbstractList>()) {
      auto compile_sequence = compile_abs->cast<abstract::AbstractSequencePtr>();
      auto args_sequence = args_abs->cast<abstract::AbstractSequencePtr>();
      CheckAbstractConsistency(compile_sequence->elements(), args_sequence->elements(), target_str);
    } else {
      if (!common::IsEqual(compile_abs, args_abs)) {
        MS_EXCEPTION(ValueError) << "For " << target_str << " or tuple(list) in " << target_str << ", the " << i + 1
                                 << "th should be " << compile_abs->ToString() << ", but got " << args_abs->ToString()
                                 << "!";
      }
    }
  }
}
}  // namespace

std::string GetObjDesc(const py::object &source) {
  std::string obj_desc;
  if (py::hasattr(source, parse::PYTHON_PARSE_METHOD)) {
    auto cell_class_name = source.attr("__class__").attr("__name__");
    auto jit_name = source.attr(parse::PYTHON_PARSE_METHOD);
    obj_desc = "'" + py::cast<std::string>(cell_class_name) + "." + py::cast<std::string>(jit_name) + "'";
  } else {
    if (py::hasattr(source, "__name__")) {
      auto jit_name = source.attr("__name__");
      obj_desc = "'" + py::cast<std::string>(jit_name) + "'";
    } else if (py::isinstance<Cell>(source)) {
      auto cell_class_name = source.attr("__class__").attr("__name__");
      obj_desc = "'" + py::cast<std::string>(cell_class_name) + ".construct'";
    } else {
      MS_EXCEPTION(TypeError) << "The source object is invalid: " << py::str(source);
    }
  }
  return obj_desc;
}

void CheckArgsValid(const py::object &source, const py::tuple &args) {
  if (!IS_OUTPUT_ON(mindspore::kInfo)) {
    return;
  }
  for (size_t i = 0; i < args.size(); i++) {
    if (!CheckArgValid(args[i])) {
      MS_LOG(INFO) << "The " << ToOrdinal(i + 1) << " arg type is " << args[i].get_type() << ", value is '"
                   << py::str(args[i]) << "'.";
    }
  }
}

void GraphExecutorPy::CheckArgumentsConsistency(const py::tuple &compile_args, const py::tuple &args_list,
                                                const py::object &target) {
  if ((!py::isinstance<py::str>(target))) {
    MS_EXCEPTION(TypeError) << "The `target` must be string!";
  }
  std::string target_str = py::cast<std::string>(target);
  if (compile_args.size() != args_list.size()) {
    MS_EXCEPTION(ValueError) << "For " << target_str
                             << ", the length of input must be equal to expected one, but got expected: "
                             << compile_args.size() << " and input: " << args_list.size() << "!";
  }

  AbstractBasePtrList compile_abstracts;
  compile_abstracts.reserve(compile_args.size());
  AbstractBasePtrList args_abstracts;
  args_abstracts.reserve(compile_args.size());
  for (size_t i = 0; i < compile_args.size(); ++i) {
    ValuePtr compile_args_converted = nullptr;
    if (!parse::ConvertData(compile_args[i], &compile_args_converted)) {
      MS_LOG(INTERNAL_EXCEPTION) << "ConvertData for " << i << "th compiling argument failed, the argument type is "
                                 << compile_args[i].get_type() << ", value is '" << py::str(compile_args[i]) << "'.";
    }
    compile_abstracts.push_back(ArgsToAbstract(compile_args[i], compile_args_converted));

    ValuePtr args_converted = nullptr;
    if (!parse::ConvertData(args_list[i], &args_converted)) {
      MS_LOG(INTERNAL_EXCEPTION) << "ConvertData for " << i << "th input argument failed, the argument type is "
                                 << args_list[i].get_type() << ", value is '" << py::str(args_list[i]) << "'.";
    }
    args_abstracts.push_back(ArgsToAbstract(args_list[i], args_converted));
  }

  CheckAbstractConsistency(compile_abstracts, args_abstracts, target_str, false);
}

py::object GraphExecutorPy::GenerateArgumentsKey(const py::object &obj, const py::tuple &args, const py::dict &kwargs,
                                                 bool enable_tuple_broaden) {
  MS_LOG(DEBUG) << "GenerateArgumentsKey args size: " << args.size()
                << ", enable_tuple_broaden: " << enable_tuple_broaden;

  abstract::AbstractBasePtrList args_abs;
  ClearCurConvertInput();
  for (std::size_t i = 0; i < args.size(); i++) {
    ValuePtr converted = nullptr;
    if (!parse::ConvertData(args[i], &converted)) {
      MS_LOG(INTERNAL_EXCEPTION) << "ConvertData for " << i << "th argument failed, the argument type is "
                                 << args[i].get_type() << ", value is '" << py::str(args[i]) << "'.";
    }
    AbstractBasePtr abs = ArgsToAbstract(args[i], converted, enable_tuple_broaden);
    (void)args_abs.emplace_back(abs);
    // The 'converted' may be a Parameter; we need to connect it to the Parameter of the func graph,
    // so we keep all inputs for the subsequent procedure.
    (void)cur_convert_input_.emplace(args[i].ptr(), std::make_pair(converted, abs));
  }
  for (const auto &item : kwargs) {
    ValuePtr key = nullptr;
    ValuePtr value = nullptr;
    bool success = parse::ConvertData(py::cast<py::object>(item.first), &key) &&
                   parse::ConvertData(py::cast<py::object>(item.second), &value);
    if (!success) {
      MS_LOG(INTERNAL_EXCEPTION) << "ConvertData for argument (" << py::str(item.first) << ": " << py::str(item.second)
                                 << ") failed.";
    }
    AbstractBasePtr value_abs = ArgsToAbstract(py::cast<py::object>(item.second), value, enable_tuple_broaden);
    auto keyword_arg_abs = std::make_shared<abstract::AbstractKeywordArg>(GetValue<std::string>(key), value_abs);
    (void)args_abs.emplace_back(keyword_arg_abs);
    (void)cur_convert_input_.emplace(item.first.ptr(), std::make_pair(value, keyword_arg_abs));
  }

  // If the cache matched, there is no need to call CheckArgsValid.
  auto iter = kArgsCache.find(args_abs);
  if (iter != kArgsCache.end()) {
    return py::int_(iter->second);
  }

  static uint64_t key_counter = 0;
  kArgsCache[args_abs] = key_counter;
  kCellArgsMap[obj.ptr()] = args_abs;
  MS_LOG(INFO) << "Generate a new compile key for new args, key: " << key_counter;
  if (IS_OUTPUT_ON(mindspore::kInfo)) {
    std::ostringstream buffer;
    buffer << "New cached args:"
           << "\n";
    for (size_t i = 0; i < args_abs.size(); ++i) {
      buffer << "Arg[" << i << "]: " << args_abs[i]->ToString() << "\n";
    }
    MS_LOG(INFO) << buffer.str();
  }
  return py::int_(key_counter++);
}
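// Usage sketch (added commentary, not in the original source): two calls whose
// arguments convert to the same abstract signature hit the same kArgsCache entry and
// return the same compile key, so e.g. calling a jit function twice with a 32x224
// float32 tensor yields one key and one compilation, while changing the dtype or a
// static shape produces a new key and triggers recompilation.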

void GraphExecutorPy::ClearCompileArgumentsResource() {
  // Clear global converted args saved in GenerateArgumentsKey.
  ClearCurConvertInput();
}

void ClearArgCache(const py::object &obj) {
  if (py::isinstance<py::none>(obj)) {
    return;
  }
  auto iter = kCellArgsMap.find(obj.ptr());
  if (iter != kCellArgsMap.end()) {
    (void)kArgsCache.erase(iter->second);
    (void)kCellArgsMap.erase(iter);
  }
}

void GraphExecutorPy::ClearCurConvertInput() { cur_convert_input_.clear(); }

void GraphExecutorPy::ParentBeforeFork() {
  MS_LOG(DEBUG) << "GraphExecutorPy prepare before fork.";
  MS_LOG(DEBUG) << "Stop AnalysisSchedule tasks.";
  abstract::AnalysisSchedule::GetInstance().Stop();
  MS_LOG(DEBUG) << "GraphExecutorPy prepare before fork done.";
}

void GraphExecutorPy::ParentAfterFork() {
  MS_LOG(DEBUG) << "GraphExecutorPy in parent process reinitialize after fork.";
  MS_LOG(DEBUG) << "Restart AnalysisSchedule tasks.";
  abstract::AnalysisSchedule::GetInstance().Start();
  MS_LOG(DEBUG) << "GraphExecutorPy in parent process reinitialize after fork done.";
}

void GraphExecutorPy::ChildAfterFork() {
  MS_LOG(DEBUG) << "GraphExecutorPy in child process reinitialize after fork.";
  MS_LOG(DEBUG) << "Restart AnalysisSchedule tasks.";
  abstract::AnalysisSchedule::GetInstance().Start();
  MS_LOG(DEBUG) << "GraphExecutorPy in child process reinitialize after fork done.";
}

py::bool_ VerifyInputSignature(const py::list &input_signature, const py::tuple &inputs) {
  MS_LOG(DEBUG) << "Verify args size:" << inputs.size();
  if (inputs.size() != input_signature.size()) {
    MS_LOG(ERROR) << "Signature size not equal to args size";
    return false;
  }

  size_t count = 0;
  for (auto arg_obj : inputs) {
    std::shared_ptr<Tensor> m_tensor = nullptr;
    bool is_tensor = false;
    if (py::isinstance<Tensor>(arg_obj)) {
      m_tensor = arg_obj.cast<std::shared_ptr<Tensor>>();
      is_tensor = true;
    } else if (IsStubTensor(arg_obj)) {
      m_tensor = ConvertStubTensor(arg_obj);
      is_tensor = true;
    }
    if (is_tensor && m_tensor == nullptr) {
      MS_LOG(ERROR) << "Verify Tensor error, get ptr is null";
      return false;
    }

    if (m_tensor != nullptr) {
      MS_LOG(DEBUG) << "Verify Tensor";
      auto sig = input_signature[count].cast<std::shared_ptr<MetaTensor>>();
      ShapeVector sig_shape = sig->shape();
      TypePtr sig_type = sig->Dtype();

      ShapeVector tensor_shape = m_tensor->shape_c();
      if (tensor_shape != sig_shape) {
        MS_LOG(ERROR) << "Python input shape is incompatible with input_signature";
        return false;
      }

      if (*m_tensor->Dtype() != *sig_type) {
        MS_LOG(ERROR) << "Python input type(" << m_tensor->Dtype()->ToString()
                      << ") incompatible with input_signature(" << sig_type->ToString() << ")";
        return false;
      }
    }
    count++;
  }

  return true;
}
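// Illustrative example (added commentary, not in the original source): with
//   input_signature = [Tensor(shape=(2, 3), dtype=float32)]
// an input tensor of shape (2, 3) and dtype float32 passes, while shape (3, 2) or
// dtype float16 makes VerifyInputSignature return false; note the shape comparison
// above is exact (tensor_shape != sig_shape).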

ResourcePtr GraphExecutorPy::GetResource(const std::string &phase) {
  MS_LOG(DEBUG) << "Phase size:" << info_.size();
  if (info_.count(phase) == 0) {
    return nullptr;
  }
  return info_[phase]->resource;
}

FuncGraphPtr GraphExecutorPy::GetFuncGraph(const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INFO) << "No executor info. found for phase: " << phase;
    return nullptr;
  }
  return it->second->func_graph;
}

void GraphExecutorPy::SetJitPrimalFuncGraph(const FuncGraphPtr &primal_func_graph, const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "No executor info. found for phase: " << phase;
    return;
  }
  MS_EXCEPTION_IF_NULL(primal_func_graph);
  it->second->jit_primal_func_graph = primal_func_graph;
}

FuncGraphPtr GraphExecutorPy::GetJitPrimalFuncGraph(const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INFO) << "No executor info. found for phase: " << phase;
    return nullptr;
  }
  return it->second->jit_primal_func_graph;
}

FuncGraphPtr GraphExecutorPy::GetJitGradGraph(const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INFO) << "No executor info. found for phase: " << phase;
    return nullptr;
  }
  return it->second->jit_grad_graph;
}

void GraphExecutorPy::SetJitGradGraph(const FuncGraphPtr &grad_graph, const std::string &phase) {
  const auto it = info_.find(phase);
  if (it == info_.end()) {
    MS_LOG(INTERNAL_EXCEPTION) << "No executor info. found for phase: " << phase;
    return;
  }
  if (it->second->jit_grad_graph != nullptr) {
    MS_LOG(DEBUG) << "The grad graph has existed, phase is: " << phase;
  }
  MS_EXCEPTION_IF_NULL(grad_graph);
  it->second->jit_grad_graph = grad_graph;
}

compile::VmEvalFuncPtr GraphExecutorPy::GetVmEvalFunc(const std::string &phase) {
  ResourcePtr res = GetResource(phase);
  MS_EXCEPTION_IF_NULL(res);
  if (res->HasResult(kOutput) && res->GetResult(kOutput).is<compile::VmEvalFuncPtr>()) {
    return res->GetResult(kOutput).cast<compile::VmEvalFuncPtr>();
  }
  MS_LOG(ERROR) << "GetVmEvalFunc vm model can't find kOutput:" << kOutput;
  return nullptr;
}

bool GraphExecutorPy::HasCompiled(const std::string &phase) const { return info_.count(phase) != 0; }

py::bytes GraphExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type,
                                             const bool &incremental) {
  FuncGraphPtr fg_ptr = GetFuncGraph(phase);
  if (fg_ptr == nullptr) {
    for (const auto &item : info_) {
      MS_LOG(DEBUG) << "Phase key is: " << item.first;
    }
    MS_LOG(EXCEPTION) << "Can not find func graph " << phase;
  }

  if (ir_type == IR_TYPE_ANF) {
    std::string proto_str = GetFuncGraphProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export ANF format model failed.";
    }
    return proto_str;
  }

  if (ir_type == IR_TYPE_ONNX) {
    std::string proto_str = GetOnnxProtoString(fg_ptr);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export ONNX format model failed.";
    }
    return proto_str;
  }

  if (ir_type == IR_TYPE_MINDIR) {
    // obfuscate model
    std::string proto_str = GetBinaryProtoString(fg_ptr, incremental);
    if (proto_str.empty()) {
      MS_LOG(EXCEPTION) << "Export MINDIR format model failed.";
    }
    return proto_str;
  }

  MS_LOG(INTERNAL_EXCEPTION) << "Unknown ir type: " << ir_type;
}

py::bytes GraphExecutorPy::GetObfuscateFuncGraphProto(const std::string &phase, const bool &incremental,
                                                      const float obf_ratio, const int branch_control_input) {
  FuncGraphPtr fg_ptr = GetFuncGraph(phase);
  // obfuscate model
  if (branch_control_input == 0) {
    (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
    MS_LOG(DEBUG) << "[GetObfuscateFuncGraphProto] set customized function names finished";
  }
  mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, branch_control_input);
  mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(fg_ptr);

  std::string proto_str = GetBinaryProtoString(obfuscated_graph, incremental);
  if (proto_str.empty()) {
    MS_LOG(EXCEPTION) << "GetBinaryProtoString failed.";
  }
  return proto_str;
}

py::bytes GraphExecutorPy::GetOptimizeGraphProto(const std::string &phase) {
  if (info_.count(phase) == 0) {
    MS_LOG(INTERNAL_EXCEPTION) << "No phase in executor: " << phase;
  }
  FuncGraphPtr fg_ptr = info_[phase]->resource->optimize_graph();
  if (fg_ptr == nullptr) {
    MS_LOG(WARNING) << "Can not find optimize graph.";
    return "";
  }
  std::string proto_str = GetFuncGraphProtoString(fg_ptr);
  if (proto_str.empty()) {
    MS_LOG(EXCEPTION) << "Export optimize graph proto string failed.";
  }
  return proto_str;
}

void GraphExecutorPy::SetJitConfig(const py::dict &config) {
  auto jit_config = GenerateJitConfigMap(config);
  PhaseManager::GetInstance().set_jit_config(jit_config);
}

py::dict GraphExecutorPy::GetParallelGraphInfo(const std::string &phase) {
  MS_LOG(DEBUG) << "GetParallelGraphInfo!";
  std::string parallel_phase = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(parallel_phase);
  if (graph == nullptr) {
    MS_LOG(INTERNAL_EXCEPTION) << "Can not access FuncGraph according to phase: " << parallel_phase;
  }

  return mindspore::parallel::GetParallelCNodeInfoFromGraph(graph);
}

py::dict GraphExecutorPy::GetParameterLayout(const std::string &phase) {
  MS_LOG(DEBUG) << "GetParameterLayout!";
  std::string layout_graph = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(layout_graph);
  if (graph == nullptr) {
    auto resource = info_[phase]->resource;
    return mindspore::parallel::GetParameterLayoutFromResource(resource);
  }
  return mindspore::parallel::GetParameterLayoutFromGraph(graph);
}

py::tuple GraphExecutorPy::FlopsCollection(const std::string &phase) {
  auto graph = GetFuncGraph(phase);
  return mindspore::parallel::FlopsCollection(graph);
}

py::dict GraphExecutorPy::GetCNodeStrategy(const std::string &phase) {
  MS_LOG(DEBUG) << "GetCNodeStrategy!";
  return stra_dict_[phase];
}

py::list GraphExecutorPy::GetParallelParameterNameList(const std::string &phase) {
  std::string param_graph = phase + kStepParallelGraph;
  auto graph = GetFuncGraph(param_graph);
  if (graph == nullptr) {
    auto resource = info_[phase]->resource;
    return mindspore::parallel::GetParallelParameterNameListFromResource(resource);
  }
  return mindspore::parallel::GetParallelParameterNameListFromGraph(graph);
}

void GraphExecutorPy::SetCNodeStrategy(const std::string &name, const parallel::Strategies &strategy) {
  MS_LOG(DEBUG) << "SetCNodeStrategy!";
  stra_dict_[phase_][py::str(name)] = strategy;
}

size_t GraphExecutorPy::GetNumOpsInfo(const std::string &phase) {
  MS_LOG(DEBUG) << "GetNumOpsInfo!";
  return phase_to_num_op_info_[phase];
}

void GraphExecutorPy::SetNumOpsInfo(size_t num_ops) {
  MS_LOG(DEBUG) << "SetNumOpsInfo!";
  phase_to_num_op_info_[phase_] = num_ops;
}

py::dict GraphExecutorPy::GetAllreduceFusion(const std::string &phase) {
  MS_LOG(INFO) << "GetAllreduceFusion!";
  auto graph = GetFuncGraph(phase);
  return mindspore::parallel::GetAllreduceFusion(graph);
}

// Does not support multi-thread or nested calls.
// A nested_called flag is used here to avoid nested calls.
void GraphExecutorPy::DelNetRes(const py::object &source, const py::set &id) {
  ClearArgCache(source);
  // Del all graphs by different phase
  for (auto item : id) {
    DelOneNetRes(item);
  }
}

void GraphExecutorPy::DelOneNetRes(const py::handle &py_phase) {
  if (!pybind11::isinstance<py::str>(py_phase)) {
    MS_LOG(ERROR) << "Expect string phase, but got " << py::str(py_phase);
    return;
  }
  auto phase = pybind11::cast<std::string>(py_phase);
  MS_LOG(INFO) << "Delete one net resource start, phase: " << phase;
  auto iter = info_.find(phase);
  auto clear = false;
  if (iter != info_.end()) {
    clear = true;
    auto res = iter->second->resource;
    if (res->HasResult(kStepParallelGraph)) {
      std::string layout_graph = phase + kStepParallelGraph;
      (void)info_.erase(layout_graph);
    }
    (void)info_.erase(phase);
    MS_LOG(DEBUG) << "Delete phase: " << phase << ", info size: " << info_.size();
  }
  if (clear) {
    // Do clear here to avoid any pointer for resource.
    FuncGraphLoopBreaker::Inst().ClearCellGraphs(phase);
    FuncGraphLoopBreaker::Inst().CleanUnusedFuncGraphs(phase);
  }
  MS_LOG(INFO) << "Delete one net resource end. " << clear;
}

void GraphExecutorPy::ClearRes() {
  MS_LOG(INFO) << "Clean executor resource!";
  executor_ = nullptr;
}

std::string GraphExecutorPy::get_queue_name(const std::string &dataset_phase) {
  return CompileCacheManager::GetCachedDataQueueName(dataset_phase);
}

GraphExecutorPy::~GraphExecutorPy() {
  MS_LOG(INFO) << "Release Executor!";
  ConfigManager::GetInstance().ResetConfig();
}

void GraphExecutorPy::SaveCompiledGraph(const std::string &phase) {
  // save the graph to GraphExecutorPy
  FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase << ")!";
  info_[phase]->func_graph = func_graph;
  func_graph->set_attr("phase", MakeValue(GetPhasePrefix(phase)));

  if ((func_graph != nullptr) && parallel::IsAutoParallelCareGraph(func_graph)) {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
    auto res = info_[phase]->resource;
    // When using frontend compile cache, model parallel parameter layout graph is not saved.
    if (res->HasResult(kStepParallelGraph)) {
      func_graph = res->GetResult(kStepParallelGraph).cast<FuncGraphPtr>();
      ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
      std::string layout_graph = phase + kStepParallelGraph;
      executor_info->func_graph = func_graph;
      info_[layout_graph] = executor_info;
    }
  } else {
    MS_LOG(DEBUG) << "Save model parallel parameter layout graph null!";
  }
  MS_LOG(INFO) << "End save compiled func graph!";
}

void GraphExecutorPy::GetGeBackendPolicy() const {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  std::string backend = ms_context->backend_policy();
  if (backend != "ge") {
    MS_LOG(EXCEPTION) << backend << " backend policy is not supported under ge backend!";
  }
}

bool IsPhaseExportAir(const std::string &phase) {
  auto phase_to_export = "export.air";
  return phase.rfind(phase_to_export) != std::string::npos;
}

bool IsPhaseExport(const std::string &phase) {
  constexpr auto export_str = "export";
  return phase.compare(0, strlen(export_str), export_str) == 0;
}

bool IsPhaseTrain(const std::string &phase) {
  const std::string phase_to_train = "train";
  return phase.rfind(phase_to_train) != std::string::npos;
}

bool IsPhaseLoadFromMindIR(const std::string &phase) {
  const std::string mindir_graph = "graph_load_from_mindir";
  return phase.rfind(mindir_graph) != std::string::npos;
}
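// Illustrative examples (added commentary, not in the original source) for the phase
// predicates above, assuming the usual "<kind>.<id>..." phase naming:
//   IsPhaseTrain("train.12345")                       -> true  (substring match)
//   IsPhaseExport("export.air")                       -> true  (prefix match)
//   IsPhaseExportAir("export.air")                    -> true
//   IsPhaseLoadFromMindIR("graph_load_from_mindir.0") -> true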

std::vector<ActionItem> GetActions(const ResourcePtr &resource, const std::string &phase, bool use_vm,
                                   bool trace_flag = false, bool erase_parse = false) {
  MS_EXCEPTION_IF_NULL(resource);
  compile::SetMindRTEnable();
  return VmPipeline(resource, trace_flag, erase_parse);
}

void GraphExecutorPy::InitCompileCacheInfo(const ResourcePtr &resource, const std::string &phase) {
  // Currently, the compilation cache is only supported for training cells or functions decorated with 'jit'.
  // If the compilation cache is enabled, a non-empty list of dependent files is fetched from Python.
  if (!CompileCacheEnable()) {
    return;
  }
  bool has_python_script = true;
  if (compile_cache_dep_files_.empty()) {
    has_python_script = false;
  }

  {
    MsProfileStatGuard stat_guard("LoadCachedFuncGraph");
    static size_t idx = 0;
    MS_EXCEPTION_IF_NULL(resource);
    resource->GetCompileCacheResource(compile_cache_dep_files_, weights_, queue_name_, idx++,
                                      &compile_cache_consistent_, has_python_script);
  }
}

void GraphExecutorPy::ParallelPostProcess(const std::string &phase, bool use_compile_cache) {
  // Slice Python parameter obj
  auto layout_graph = phase + kStepParallelGraph;
  // Only the parallel graph has tensor_layout.
  auto root = GetFuncGraph(layout_graph);
  bool after_shard = false;
  if (phase.find("after_shard") != std::string::npos) {
    after_shard = true;
  }
  // Use compile cache
  if (use_compile_cache) {
    parallel::InitCompileCacheParams(info_[phase]->resource);
    return;
  }
  // Initialize parameters for graphs that auto-parallel does not care about.
  if (root == nullptr && !after_shard) {
    auto graph = info_[phase]->resource->func_graph();
    MS_EXCEPTION_IF_NULL(graph);
    parallel::InitPynativeNoShardParams(graph);
    return;
  }
  MS_EXCEPTION_IF_NULL(root);
  parallel::AutoParallelPostProcess(root);
}

// Clean all resources that will not be used later, and the cache generated during compiling.
void GraphExecutorPy::CleanCompileRes(const ResourcePtr &resource) {
  MS_LOG(INFO) << "Clean compile resource start";
  ProcessStatus::GetInstance().RecordStart(kPipelineClean);
  (void)profiler::CollectHostInfo(kCompiler, kPipelineClean, kPipelineClean, 0, 0, 0);
  abstract::AnalysisContext::ClearContext();
  ClearCompileArgumentsResource();
  ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
  ad::g_k_prims.clear();
  ad::DFunctor::Clear();
  ReclaimOptimizer();
  resource->Clean();
  auto parallel_context = parallel::ParallelContext::GetInstance();
  MS_EXCEPTION_IF_NULL(parallel_context);
  if (parallel_context->hccl_test_available()) {
    parallel::g_device_manager = nullptr;
  }
  FuncGraphLoopBreaker::Inst().CleanMetaFuncGraphs();
  (void)profiler::CollectHostInfo(kCompiler, kPipelineClean, kPipelineClean, 0, 0, 1);
  ProcessStatus::GetInstance().RecordEnd();
  CompileCacheContext::GetInstance().Clear();
  parse::Parser::CleanParserResource();
  MS_LOG(INFO) << "Clean compile resource end";
}

bool GraphExecutorPy::CompileInner(const FuncGraphPtr &graph, const py::tuple &args, const py::dict &kwargs,
                                   const std::string &phase, bool use_vm, bool trace_flag) {
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  ms_context->SetCellReuseLevel(CellReuseLevel::kNoCellReuse);
  PhaseManager::GetInstance().set_phase(phase);
  phase_ = phase;

  ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
  ResourcePtr resource = std::make_shared<Resource>();
  resource->set_func_graph(graph);
  InitCompileCacheInfo(resource, phase);
  bool use_compile_cache = resource->EnableCompileCache() && resource->func_graph();
  ConfigManager::GetInstance().ResetQueue(queue_name_);

  bool erase_parse = true;
  auto actions = GetActions(resource, phase, use_vm, trace_flag, erase_parse);
  std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, actions);

  if (pip->NeedCreateBackend()) {
    // Create backend asynchronously.
    resource->SetBackendAsync([]() {
      auto backend = compile::CreateBackend();
#ifdef ENABLE_DEBUGGER
      // Connect session to debugger.
      backend->SetDebugger();
#endif
      return backend;
    });
  }

  // Get the parameters items and add the value to args_abs.
  abstract::AbstractBasePtrList args_abs;
  std::vector<ValuePtr> arguments;
  MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
  bool is_auto_parallel = (parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
                           parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel);
  ConvertArgs(args, kwargs, is_auto_parallel, &args_abs, &arguments);
  ConvertSymbolicShape(args, &args_abs);
  AddManagerForFuncGraphArgs(resource, arguments);
  resource->set_arguments(arguments);
  resource->set_args_abs(args_abs);
  executor_info->arg_list_size = args.size() + kwargs.size();
  executor_info->resource = resource;
  info_[phase] = executor_info;
  pip->Run();

  // Save the compiled graph to MsPipeLine.
  SaveCompiledGraph(phase);
  if (is_auto_parallel) {
    ParallelPostProcess(phase, use_compile_cache);
  }
#ifdef ENABLE_DUMP_IR
  mindspore::RDR::Snapshot();
#endif
  CleanCompileRes(resource);
  PhaseManager::GetInstance().ClearPhase();
  MS_LOG(INFO) << "Finish compiling.";
  return true;
}
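// Flow summary (added commentary, not in the original source): both CompileInner
// overloads follow the same shape: build a Resource, probe the compile cache, pick the
// action list (VmPipeline), create the backend asynchronously, convert Python
// args/kwargs into abstracts, run the pipeline, save the compiled graph, run parallel
// post-processing when (semi-)auto parallel is enabled, then clean compile-time
// resources.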
1212
CompileInner(const py::object & source,const py::tuple & args,const py::dict & kwargs,const py::object & phase,bool use_vm)1213 bool GraphExecutorPy::CompileInner(const py::object &source, const py::tuple &args, const py::dict &kwargs,
1214 const py::object &phase, bool use_vm) {
1215 auto ms_context = MsContext::GetInstance();
1216 MS_EXCEPTION_IF_NULL(ms_context);
1217 ms_context->SetCellReuseLevel(CellReuseLevel::kNoCellReuse);
1218 // Check if the phase is valid.
1219 if ((!py::isinstance<py::str>(phase))) {
1220 MS_LOG(ERROR) << "The `phase` must be string.";
1221 return false;
1222 }
1223 // Check if the function or net is valid.
1224 if (py::isinstance<py::none>(source)) {
1225 MS_LOG(ERROR) << "The source object to compile should not be None.";
1226 return false;
1227 }
1228 // Check if the args of function or net is valid.
1229 CheckArgsValid(source, args);
1230
1231 source_ = py::cast<std::string>(py::str(source));
1232 phase_ = py::cast<std::string>(phase);
1233 PhaseManager::GetInstance().set_phase(phase_);
1234 obj_desc_ = GetObjDesc(source);
1235 MS_LOG(INFO) << "Start compiling, phase: " << phase_;
1236 PROF_START(compile_graph);
1237 MS_LOG(DEBUG) << "source: {" << source_ << "}\nargs: " << py::str(const_cast<py::tuple &>(args))
1238 << "\nkwargs: " << py::str(const_cast<py::dict &>(kwargs));
1239 EventMessage::PrintCompileStartMsg(phase_, obj_desc_);
1240
1241 ExecutorInfoPtr executor_info = std::make_shared<ExecutorInfo>();
1242 ResourcePtr resource = std::make_shared<Resource>(source);
1243 InitCompileCacheInfo(resource, phase_);
1244 bool enable_compile_cache = resource->EnableCompileCache();
1245 bool use_compile_cache = enable_compile_cache && resource->func_graph();
1246 ConfigManager::GetInstance().ResetQueue(queue_name_);
1247 auto &compile_cache_context = CompileCacheContext::GetInstance();
1248 compile_cache_context.SetUseCompileCache(use_compile_cache);
1249
1250 auto actions = GetActions(resource, phase_, use_vm, false, false);
1251 std::shared_ptr<Pipeline> pip = std::make_shared<Pipeline>(resource, actions);
1252
1253 (void)profiler::CollectHostInfo(kCompiler, kCreateBackend, kCreateBackend, 0, 0, 0);
1254 if (pip->NeedCreateBackend()) {
1255 // Create backend asynchronously.
1256 resource->SetBackendAsync([]() {
1257 auto backend = compile::CreateBackend();
1258 #ifdef ENABLE_DEBUGGER
1259 // Connect session to debugger.
1260 backend->SetDebugger();
1261 #endif
1262 return backend;
1263 });
1264 }
1265 (void)profiler::CollectHostInfo(kCompiler, kCreateBackend, kCreateBackend, 0, 0, 1);
1266
1267 // Get the parameter items and add their values to args_abs.
1268 abstract::AbstractBasePtrList args_abs;
1269 std::vector<ValuePtr> arguments;
1270 MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
1271 bool is_parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kSemiAutoParallel ||
1272 parallel::ParallelContext::GetInstance()->parallel_mode() == parallel::kAutoParallel;
1273 bool is_auto_parallel = is_parallel_mode && !py::hasattr(source, parallel::kSkipAutoParallelCompile) &&
1274 !py::hasattr(source, parallel::kKeepInputUnchanged);
1275 ConvertArgs(args, kwargs, is_auto_parallel, &args_abs, &arguments);
1276 ConvertSymbolicShape(args, &args_abs);
1277 AddManagerForFuncGraphArgs(resource, arguments);
1278 resource->set_arguments(arguments);
1279 resource->set_args_abs(args_abs);
1280 executor_info->arg_list_size = args.size() + kwargs.size();
1281 executor_info->resource = resource;
1282 info_[phase_] = executor_info;
1283 pip->Run();
1284
1285 // Save the compiled graph to MsPipeLine.
1286 SaveCompiledGraph(phase_);
1287 if (is_parallel_mode) {
1288 ParallelPostProcess(phase_, use_compile_cache);
1289 }
1290 #ifdef ENABLE_DUMP_IR
1291 mindspore::RDR::Snapshot();
1292 #endif
1293 CleanCompileRes(resource);
1294 EventMessage::PrintCompileEndMsg(phase_, obj_desc_);
1295 PhaseManager::GetInstance().ClearPhase();
1296 MS_LOG(INFO) << "Finish compiling.";
1297 PROF_END(compile_graph);
1298 return true;
1299 }
1300
1301 void GraphExecutorPy::ConvertArgs(const py::tuple &args, const py::dict &kwargs, bool is_auto_parallel,
1302 abstract::AbstractBasePtrList *args_abs, std::vector<ValuePtr> *arguments) {
1303 MS_EXCEPTION_IF_NULL(args_abs);
1304 MS_EXCEPTION_IF_NULL(arguments);
1305 for (std::size_t i = 0; i < args.size(); i++) {
1306 // In some parallel modes a full tensor is needed, which makes the args of GenerateArgumentsKey differ from
1307 // those used for compilation, so cur_convert_input_ cannot be used directly.
1308 auto iter = cur_convert_input_.find(args[i].ptr());
1309 if (iter != cur_convert_input_.end()) {
1310 (void)arguments->emplace_back(iter->second.first);
1311 if (is_auto_parallel) {
1312 auto abs_item = iter->second.second->Clone();
1313 (void)parallel::ExtendInputArgsAbstractShape(abs_item, i);
1314 (void)args_abs->emplace_back(abs_item);
1315 continue;
1316 }
1317 (void)args_abs->emplace_back(iter->second.second);
1318 continue;
1319 }
1320 ValuePtr converted = nullptr;
1321 bool success = parse::ConvertData(args[i], &converted);
1322 if (!success) {
1323 MS_LOG(INTERNAL_EXCEPTION) << "Failed to convert the " << i << "th argument, args[" << i
1324 << "]: " << py::str(args[i]);
1325 }
1326 (void)arguments->emplace_back(converted);
1327 auto args_abstract_item = ArgsToAbstract(args[i], converted, enable_tuple_broaden_);
1328 if (is_auto_parallel) {
1329 (void)parallel::ExtendInputArgsAbstractShape(args_abstract_item, i);
1330 }
1331 (void)args_abs->emplace_back(args_abstract_item);
1332 }
1333 for (const auto &item : kwargs) {
1334 auto iter = cur_convert_input_.find(item.first.ptr());
1335 if (iter != cur_convert_input_.end()) {
1336 (void)arguments->emplace_back(iter->second.first);
1337 (void)args_abs->emplace_back(iter->second.second);
1338 continue;
1339 }
1340 ValuePtr key = nullptr;
1341 ValuePtr value = nullptr;
1342 bool success = parse::ConvertData(py::cast<py::object>(item.first), &key) &&
1343 parse::ConvertData(py::cast<py::object>(item.second), &value);
1344 if (!success) {
1345 MS_LOG(INTERNAL_EXCEPTION) << "Failed to convert the argument (" << py::str(item.first) << ": "
1346 << py::str(item.second) << ").";
1347 }
1348 AbstractBasePtr value_abs = ArgsToAbstract(py::cast<py::object>(item.second), value, enable_tuple_broaden_);
1349 auto keyword_arg_abs = std::make_shared<abstract::AbstractKeywordArg>(GetValue<std::string>(key), value_abs);
1350 (void)arguments->emplace_back(value);
1351 (void)args_abs->emplace_back(keyword_arg_abs);
1352 }
1353 }
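
// --- Illustrative sketch (not part of the original file) ---
// The same two-pass walk ConvertArgs performs above, reduced to plain pybind11
// so the ordering is easy to see: positional args are converted first, keyword
// args are appended after them, which is why arg_list_size is
// args.size() + kwargs.size(). The helper name CollectArgReprs is hypothetical;
// it reuses the pybind11 include and `py` alias already present in this file.
#include <cstddef>
#include <string>
#include <vector>

inline std::vector<std::string> CollectArgReprs(const py::tuple &args, const py::dict &kwargs) {
  std::vector<std::string> reprs;
  for (std::size_t i = 0; i < args.size(); ++i) {
    // Positional arguments keep their original order.
    (void)reprs.emplace_back(py::cast<std::string>(py::str(args[i])));
  }
  for (const auto &item : kwargs) {
    // Keyword arguments follow, rendered as "key=value".
    (void)reprs.emplace_back(py::cast<std::string>(py::str(item.first)) + "=" +
                             py::cast<std::string>(py::str(item.second)));
  }
  return reprs;
}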
1354
1355 void GraphExecutorPy::ConvertSymbolicShape(const py::tuple &args, AbstractBasePtrList *args_abs) {
1356 std::vector<symshape::SymbolInfoList> symbol_infos;
1357 symbol_infos.reserve(args_abs->size());
1358 bool has_dyn_shape = false;
1359 bool is_parallel = parallel::IsSemiOrAutoParallelMode();
1360
1361 for (size_t i = 0; i < args.size(); i++) {
1362 auto iter = cur_convert_input_.find(args[i].ptr());
1363 if (iter == cur_convert_input_.end()) {
1364 continue;
1365 }
1366 auto &info_list = symbol_infos.emplace_back(symshape::SymbolInfoList{});
1367 if (!iter->second.first->isa<MetaTensor>()) {
1368 continue;
1369 }
1370 auto digital_shape = iter->second.second->GetShape();
1371 if (digital_shape->IsDynamic()) {
1372 has_dyn_shape = true;
1373 }
1374 constexpr char symbolic_shape_attr[] = "symbolic_shape";
1375 if (!py::hasattr(args[i], symbolic_shape_attr)) {
1376 if (is_parallel) {
1377 if (digital_shape != nullptr && digital_shape->isa<abstract::TensorShape>()) {
1378 info_list.resize(digital_shape->GetShapeVector().size());
1379 }
1380 }
1381 continue;
1382 }
1383 auto symbolic_shape_obj = py::getattr(args[i], symbolic_shape_attr);
1384 MS_EXCEPTION_IF_CHECK_FAIL(py::isinstance<py::list>(symbolic_shape_obj), "tensor.symbolic_shape should be a list");
1385 auto obj_list = py::cast<py::list>(symbolic_shape_obj);
1386 info_list.resize(obj_list.size());
1387 for (size_t j = 0; j < obj_list.size(); j++) {
1388 if (!py::isinstance<py::dict>(obj_list[j])) {
1389 continue;
1390 }
1391 auto dict_obj = py::cast<py::dict>(obj_list[j]);
1392 for (auto cfg_iter = dict_obj.begin(); cfg_iter != dict_obj.end(); ++cfg_iter) {
1393 auto cfg_key = py::cast<std::string>(cfg_iter->first);
1394 if (cfg_key == "max") {
1395 info_list[j].max = py::cast<int64_t>(cfg_iter->second);
1396 } else if (cfg_key == "min") {
1397 info_list[j].min = py::cast<int64_t>(cfg_iter->second);
1398 } else if (cfg_key == "divisor") {
1399 info_list[j].divisor = py::cast<int64_t>(cfg_iter->second);
1400 } else if (cfg_key == "remainder") {
1401 info_list[j].remainder = py::cast<int64_t>(cfg_iter->second);
1402 } else if (cfg_key == "id") {
1403 info_list[j].id = py::cast<int64_t>(cfg_iter->second);
1404 } else if (cfg_key == "name") {
1405 info_list[j].name = py::cast<std::string>(cfg_iter->second);
1406 }
1407 }
1408 }
1409 }
1410
1411 MS_LOG(DEBUG) << "Symbol info before parallel processing:";
1412 parallel::PrintSymbolInfo(symbol_infos);
1413 symbol_infos = parallel::ParallelSymbolInfo(symbol_infos, has_dyn_shape);
1414 MS_LOG(DEBUG) << "Symbol info after parallel processing:";
1415 parallel::PrintSymbolInfo(symbol_infos);
1416
1417 auto symbolic_shape_list = symshape::BuildSymbolicShapeBySymbolInfo(*args_abs, symbol_infos);
1418 for (size_t i = 0; i < symbolic_shape_list.size(); i++) {
1419 // When the same tensor object is used in the set_inputs interface, the inputs may share the same Abstract object,
1420 // but for dynamic shape the same "-1" in the abstract can correspond to different symbolic shapes.
1421 auto abs = symshape::CloneAbstractIfSymbolExists((*args_abs)[i]);
1422 MS_EXCEPTION_IF_NULL(abs);
1423 abs->SetSymbolicShape(symbolic_shape_list[i]);
1424 (*args_abs)[i] = abs;
1425 }
1426 }
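
// --- Illustrative sketch (not part of the original file) ---
// The per-dimension config dicts parsed by the loop above. On the Python side a
// caller would attach something like
//   tensor.symbolic_shape = [{"max": 64, "min": 1, "divisor": 8, "remainder": 0}, {}]
// with one dict per dimension; keys other than max/min/divisor/remainder/id/name
// are silently ignored. Below, the same key dispatch written table-driven, with a
// hypothetical DimInfo struct standing in for the real symbol info type.
#include <cstdint>
#include <map>
#include <string>

struct DimInfo {
  int64_t max = 0, min = 0, divisor = 0, remainder = 0, id = 0;
  std::string name;
};

inline void SetDimField(DimInfo *info, const std::string &key, int64_t value) {
  // Map each integer config key to the corresponding member; "name" is a string
  // and would be handled separately.
  static const std::map<std::string, int64_t DimInfo::*> kFields = {
    {"max", &DimInfo::max},         {"min", &DimInfo::min},
    {"divisor", &DimInfo::divisor}, {"remainder", &DimInfo::remainder},
    {"id", &DimInfo::id}};
  auto iter = kFields.find(key);
  if (iter != kFields.end()) {
    info->*(iter->second) = value;
  }
}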
1427
1428 void GraphExecutorPy::ReleaseResourceOnException(const py::object &phase) {
1429 bool clear = false;
1430 // Make sure the pointer res is destroyed before calling DelOneNetRes.
1431 {
1432 ResourcePtr res = GetResource(py::cast<std::string>(phase));
1433 if (res != nullptr) {
1434 clear = true;
1435 CleanCompileRes(res);
1436 }
1437 }
1438 ProcessStatus::GetInstance().Clear();
1439 if (clear) {
1440 DelOneNetRes(phase);
1441 }
1442 }
1443
1444 bool GraphExecutorPy::Compile(const py::object &source, const py::tuple &args, const py::dict &kwargs,
1445 const py::object &phase, bool use_vm) {
1446 bool res = false;
1447 HandleExceptionRethrow(
1448 [this, &res, &source, &args, &kwargs, &phase, use_vm]() {
1449 if (executor_running_) {
1450 MS_LOG(EXCEPTION) << "Nested execution during JIT execution for " << GetObjDesc(source) << " is not supported "
1451 << "while " << obj_desc_ << " is compiling and executing. For more details, please refer to "
1452 << "https://www.mindspore.cn/search?inputValue=Nested%20execution";
1453 }
1454 ProcessStatus::GetInstance().RecordStart(kCompiler);
1455 std::map<std::string, std::string> custom_info;
1456 custom_info["phase"] = py::cast<std::string>(phase);
1457 (void)profiler::CollectHostInfo(kCompiler, kCompiler, kCompiler, 1, 0, 0, custom_info);
1458 res = CompileInner(source, args, kwargs, phase, use_vm);
1459 (void)profiler::CollectHostInfo(kCompiler, kCompiler, kCompiler, 1, 0, 1, custom_info);
1460 ProcessStatus::GetInstance().RecordEnd();
1461 ProcessStatus::GetInstance().Print();
1462 },
1463 [this, &phase]() {
1464 if (!StaticAnalysisException::Instance().HasException()) {
1465 // Print the function call stack before releasing resources.
1466 std::string compile_exception_info = GetCompileExceptionInfo();
1467 if (!compile_exception_info.empty()) {
1468 MS_LOG(ERROR) << compile_exception_info;
1469 }
1470 }
1471 ReleaseResourceOnException(phase);
1472 },
1473 [this, &phase]() { ReleaseResourceOnException(phase); }, [this, &phase]() { ReleaseResourceOnException(phase); });
1474 return res;
1475 }
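
// --- Illustrative sketch (not part of the original file) ---
// The wrapper pattern Compile relies on: run the main action and, on any
// exception, run a cleanup callback before rethrowing so the caller still sees
// the original error. A simplified stand-in for HandleExceptionRethrow,
// assuming a single cleanup path instead of the category-specific ones above.
#include <functional>

inline void RunWithCleanupOnError(const std::function<void()> &main_action,
                                  const std::function<void()> &cleanup) {
  try {
    main_action();
  } catch (...) {
    cleanup();  // Mirrors ReleaseResourceOnException: drop per-phase resources.
    throw;      // Rethrow so the Python side still receives the failure.
  }
}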
1476
1477 void CacheFuncGraph(const ResourcePtr &resource) {
1478 if (!resource->EnableCompileCache()) {
1479 return;
1480 }
1481 {
1482 MsProfileStatGuard stat_guard("SaveCacheFuncGraph");
1483 resource->CacheFuncGraph();
1484 }
1485 }
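
// --- Illustrative sketch (not part of the original file) ---
// The scoped-profiling pattern used above: MsProfileStatGuard measures exactly
// the block it lives in because the measurement stops in the destructor. A
// minimal, hypothetical equivalent:
#include <chrono>
#include <iostream>
#include <string>
#include <utility>

struct ScopedTimer {
  explicit ScopedTimer(std::string tag) : tag_(std::move(tag)), start_(std::chrono::steady_clock::now()) {}
  ~ScopedTimer() {
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::steady_clock::now() - start_);
    std::cout << tag_ << " took " << us.count() << " us\n";  // Reported on scope exit.
  }
  std::string tag_;
  std::chrono::steady_clock::time_point start_;
};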
1486
1487 void CheckInterpretNodeLineInfos() {
1488 auto &py_interpret_nodes = InterpretNodeRecorder::GetInstance().PyInterpretNodes();
1489 auto &py_execute_nodes = InterpretNodeRecorder::GetInstance().PyExecuteNodes();
1490 if (py_interpret_nodes.empty() && py_execute_nodes.empty()) {
1491 return;
1492 }
1493
1494 std::stringstream ss;
1495 ss << "Found unsupported syntax in graph mode; this code will fall back to the Python interpreter:\n";
1496 // Dump for PyInterpret.
1497 ss << "----------------------------------------\n";
1498 ss << " After Parser Phase (total: " << py_interpret_nodes.size() << ")\n";
1499 ss << "----------------------------------------\n";
1500 size_t num = 1;
1501 for (const auto &node : py_interpret_nodes) {
1502 const auto line_info = trace::GetDebugInfoStr(node->debug_info());
1503 ss << "# No. " << num << ":\n" << line_info << "\n";
1504 ++num;
1505 }
1506 ss << "\n";
1507 // Dump for PyExecute.
1508 ss << "----------------------------------------\n";
1509 ss << " After Optimizer Phase (total: " << py_execute_nodes.size() << ")\n";
1510 ss << "----------------------------------------\n";
1511 num = 1;
1512 for (const auto &node : py_execute_nodes) {
1513 ss << "# No. " << num << ":\n";
1514 const auto &cnode = node->cast<CNodePtr>();
1515 MS_EXCEPTION_IF_NULL(cnode);
1516 const auto &weak_script_node = cnode->weak_input(1);
1517 const auto &script_node = weak_script_node.lock();
1518 MS_EXCEPTION_IF_NULL(script_node);
1519 const auto &script = GetValueNode<StringImmPtr>(script_node);
1520 // Usually the script is a value node.
1521 std::string script_str;
1522 if (script != nullptr) {
1523 script_str = script->value();
1524 } else {
1525 const auto &script_abs = script_node->abstract();
1526 if (script_abs != nullptr) {
1527 const auto script_abs_scalar = script_abs->cast<abstract::AbstractScalarPtr>();
1528 auto script_value = script_abs_scalar == nullptr ? nullptr : script_abs_scalar->BuildValue();
1529 MS_EXCEPTION_IF_NULL(script_value);
1530 auto script_value_str = script_value->cast<StringImmPtr>();
1531 MS_EXCEPTION_IF_NULL(script_value_str);
1532 script_str = script_value_str->value();
1533 }
1534 }
1535 if (!script_str.empty()) {
1536 ss << "Script: " << script_str << "\n\n";
1537 } else {
1538 ss << "Node: " << node->DebugString() << "\n\n";
1539 }
1540 const auto line_info = trace::GetDebugInfoStr(node->debug_info());
1541 ss << line_info << "\n";
1542 ++num;
1543 }
1544 ss << "\n";
1545 ss << "----------------------------------------\n";
1546
1547 // Print the code that runs in JIT Fallback.
1548 if (common::GetEnv("MS_DEV_FALLBACK_DUMP_NODE") == "1") {
1549 MS_LOG(ERROR) << ss.str();
1550 } else {
1551 MS_LOG(INFO) << ss.str();
1552 }
1553 InterpretNodeRecorder::GetInstance().Clear();
1554 }
1555
1556 #ifdef ENABLE_DUMP_IR
1557 void RDRRecordGraph(const size_t action_index, const size_t action_size, const std::string &filename,
1558 const FuncGraphPtr &graph) {
1559 if (mindspore::RecorderManager::Instance().RdrEnable()) {
1560 MS_LOG(INFO) << "Recording FuncGraph in pipeline using RDR.";
1561 if (graph != nullptr) {
1562 auto graph_clone = BasicClone(graph);
1563 if (graph_clone != nullptr) {
1564 DumpGraphParams dump_params = {false, static_cast<int>(kTopStack)};
1565 if (action_index == action_size) {
1566 dump_params.dump_mode = static_cast<int>(kWholeStack);
1567 }
1568 (void)mindspore::RDR::RecordAnfGraph(SUBMODULE_ID, filename, graph_clone, dump_params, ".ir");
1569 } else {
1570 MS_LOG(WARNING) << "Clone FuncGraph failed in pipeline, no FuncGraph recording in RDR.";
1571 }
1572 } else {
1573 MS_LOG(WARNING) << "Pipeline Resource has no FuncGraph, no FuncGraph recording in RDR.";
1574 }
1575 MS_LOG(INFO) << "Recording FuncGraph in pipeline end.";
1576 }
1577 }
1578 #endif
1579
1580 #ifdef ENABLE_DUMP_IR
1581 void RecordIR(const size_t action_index, const size_t action_size, const std::string &action_name,
1582 const FuncGraphPtr &graph, FuncGraphPtr *user_graph) {
1583 auto context = MsContext::GetInstance();
1584 MS_EXCEPTION_IF_NULL(context);
1585 if (context->CanDump(kIntroductory) && graph != nullptr) {
1586 *user_graph = graph;
1587 std::string base_name = GetBaseNameForIR(SizeToLong(action_index), action_name);
1588
1589 // Generate IR file in human-readable format
1590 static const auto switch_order = (common::GetEnv("MS_DEV_SAVE_GRAPHS_SORT_MODE") == "1");
1591 if (switch_order) {
1592 ExportIR(base_name + ".ir", graph);
1593 } else {
1594 DumpIR(base_name + ".ir", graph, true, kWholeStack);
1595 }
1596 if (context->CanDump(kFully)) {
1597 draw::Draw(base_name + ".dot", graph);
1598 }
1599 }
1600 }
1601 #endif
1602
1603 #ifndef ENABLE_SECURITY
1604 void SaveGraphForReadability(const std::string &action_name, const FuncGraphPtr &graph, const ResourcePtr &resource) {
1605 if (graph != nullptr && action_name.find("optimize") != string::npos) {
1606 #ifdef ENABLE_DUMP_IR
1607 auto context = MsContext::GetInstance();
1608 MS_EXCEPTION_IF_NULL(context);
1609 if (context->CanDump(kIntroductory)) {
1610 DumpIRProto(graph, action_name);
1611 }
1612 #endif
1613 resource->set_optimize_graph(graph);
1614 }
1615 }
1616 #endif
1617
1618 void Pipeline::Run() {
1619 MS_LOG(INFO) << "Pipeline run";
1620 MS_EXCEPTION_IF_NULL(resource_);
1621 FuncGraphPtr user_graph = nullptr;
1622 const std::string last_compile_action = kValidate;
1623 bool already_print_profile = false;
1624 static const auto compile_profile_finish_action = common::GetCompileConfig("COMPILE_PROFILE_FINISH_ACTION");
1625 ProfileExecute(MsProfile::GetProfile(), [this, &user_graph, &last_compile_action, &already_print_profile]() {
1626 size_t i = 0;
1627 for (auto &action : actions_) {
1628 #ifdef ENABLE_TIMELINE
1629 DumpTime &dump_time = DumpTime::GetInstance();
1630 dump_time.Record(action.first, GetTime(), true);
1631 #endif
1632 ProcessStatus::GetInstance().RecordStart(action.first);
1633 (void)profiler::CollectHostInfo(kCompiler, action.first, action.first, 0, 0, 0);
1634 bool result = true;
1635 ProfileExecute(MsProfile::GetProfile()->Step(action.first), [&result, &action, this]() {
1636 MS_LOG(INFO) << "Status record: start " << action.first << " action.";
1637 result = action.second(resource_);
1638 MS_LOG(INFO) << "Status record: end " << action.first << " action.";
1639 if (IS_OUTPUT_ON(mindspore::kInfo)) {
1640 auto manager = resource_->func_graph()->manager();
1641 MS_EXCEPTION_IF_NULL(manager);
1642 MS_LOG(INFO) << "Extra status record: total func graphs: " << manager->func_graphs().size()
1643 << ", total nodes: " << manager->all_nodes().size();
1644 }
1645 });
1646 (void)profiler::CollectHostInfo(kCompiler, action.first, action.first, 0, 0, 1);
1647 ProcessStatus::GetInstance().RecordEnd();
1648 if (!result) {
1649 MS_LOG(INTERNAL_EXCEPTION) << "Pipeline run failed in step: " << action.first;
1650 }
1651
1652 if (EnabledProfile() && compile_profile_finish_action == action.first) {
1653 ProfileExecuteBreak(MsProfile::GetProfile());
1654 MsProfile::Print();
1655 already_print_profile = true;
1656 }
1657
1658 if (action.first == kTaskEmit) {
1659 SetLoopCount(resource_);
1660 } else if (action.first == last_compile_action) {
1661 CheckInterpretNodeLineInfos();
1662 CacheFuncGraph(resource_);
1663 #ifndef ENABLE_SECURITY
1664 #ifdef WITH_BACKEND
1665 MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
1666 if (MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
1667 const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
1668 {kAscendDevice, MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
1669 MS_EXCEPTION_IF_NULL(device_context);
1670 MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
1671 device_context->GetDeprecatedInterface()->DumpProfileParallelStrategy(resource_->func_graph());
1672 }
1673 #endif
1674 #endif
1675 ResetId(resource_);
1676 }
1677 FuncGraphPtr graph = resource_->func_graph();
1678 #ifdef ENABLE_DUMP_IR
1679 std::string filename = GetBaseNameForIR(SizeToLong(i), action.first);
1680 RDRRecordGraph(i, actions_.size(), filename, graph);
1681 RecordIR(i, actions_.size(), action.first, graph, &user_graph);
1682 #endif
1683 #ifndef ENABLE_SECURITY
1684 SaveGraphForReadability(action.first, graph, resource_);
1685 #endif
1686 i++;
1687 #ifdef ENABLE_TIMELINE
1688 dump_time.Record(action.first, GetTime(), false);
1689 #endif
1690 }
1691 });
1692
1693 if (EnabledProfile()) {
1694 if (!already_print_profile) {
1695 MsProfile::Print();
1696 }
1697 MsProfile::Reset();
1698 }
1699
1700 #ifdef ENABLE_DUMP_IR
1701 auto context = MsContext::GetInstance();
1702 MS_EXCEPTION_IF_NULL(context);
1703 if (context->CanDump(kIntroductory) && (user_graph != nullptr)) {
1704 if (context->CanDump(kFully)) {
1705 draw::DrawUserFuncGraph("ModelDigraph.dot", user_graph);
1706 }
1707 }
1708 if (common::GetEnv("DUMP_PARALLEL_INFO") == "1") {
1709 std::unordered_map<std::string, std::vector<uint32_t>> group_map;
1710 if (distributed::collective::CollectiveManager::instance()->initialized()) {
1711 group_map = distributed::collective::CollectiveManager::instance()->get_group_map();
1712 }
1713 if (parallel::g_device_manager == nullptr) {
1714 MS_LOG(WARNING) << "parallel::g_device_manager is not initialized. Skip dump parallel info.";
1715 } else {
1716 auto global_rank_id = parallel::g_device_manager->global_rank();
1717 DumpParallelJson("dump_parallel_info_" + std::to_string(global_rank_id) + ".json", resource_->func_graph(),
1718 global_rank_id, group_map);
1719 }
1720 }
1721 #endif
1722 MS_LOG(INFO) << "End";
1723 }
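
// --- Illustrative sketch (not part of the original file) ---
// The action-list pattern Pipeline::Run drives: a compile pipeline is an
// ordered vector of (name, callable) pairs, each of which transforms a shared
// resource and reports success. MiniResource/MiniAction are simplified,
// hypothetical types; the real ActionItem operates on a ResourcePtr.
#include <functional>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct MiniResource { /* shared compile state */ };
using MiniAction = std::pair<std::string, std::function<bool(MiniResource *)>>;

inline void RunActions(const std::vector<MiniAction> &actions, MiniResource *res) {
  for (const auto &action : actions) {
    // Each action sees the state the previous action left behind.
    if (!action.second(res)) {
      throw std::runtime_error("Pipeline run failed in step: " + action.first);
    }
  }
}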
1724
1725 bool Pipeline::NeedCreateBackend() {
1726 return std::any_of(actions_.begin(), actions_.end(),
1727 [](const ActionItem &action) { return action.first == kTaskEmit || action.first == kExecute; });
1728 }
1729
1730 void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef *const arg_list) {
1731 MS_EXCEPTION_IF_NULL(arg_list);
1732 bool arg_list_inited = !arg_list->empty();
1733 for (std::size_t i = 0; i < args.size(); i++) {
1734 py::object arg = args[i];
1735 ValuePtr converted = nullptr;
1736 bool succ = parse::ConvertData(arg, &converted);
1737 if (!succ) {
1738 MS_LOG(INTERNAL_EXCEPTION) << "Failed to convert the " << i << "th argument.";
1739 }
1740 if (!arg_list_inited) {
1741 arg_list->push_back(converted);
1742 continue;
1743 }
1744 if (i >= arg_list->size()) {
1745 MS_LOG(INTERNAL_EXCEPTION) << "Index " << i << " is out of range; arg_list size: " << arg_list->size();
1746 }
1747 (*arg_list)[i] = converted;
1748 }
1749
1750 MS_EXCEPTION_IF_NULL(res);
1751 auto graph = res->func_graph();
1752 MS_EXCEPTION_IF_NULL(graph);
1753 const std::vector<AnfNodePtr> &graph_params = graph->parameters();
1754 std::size_t graph_params_size = graph_params.size();
1755 if ((*arg_list).size() != graph_params_size) {
1756 // The remaining parameters may have default values.
1757 for (std::size_t i = (*arg_list).size(); i < graph_params_size; i++) {
1758 MS_EXCEPTION_IF_NULL(graph_params[i]);
1759 auto param_ptr = (graph_params[i])->cast_ptr<Parameter>();
1760 MS_EXCEPTION_IF_NULL(param_ptr);
1761 if (!param_ptr->has_default()) {
1762 MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
1763 }
1764 if (!param_ptr->default_param()->isa<Tensor>()) {
1765 MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->ToString()
1766 << "] is not initialized, need to call `.init_data()`";
1767 }
1768 arg_list->push_back(param_ptr->default_param());
1769 }
1770 }
1771 }
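
// --- Illustrative sketch (not part of the original file) ---
// The padding rule in ProcessVmArgInner: when fewer runtime arguments arrive
// than the graph has parameters, the tail must be covered by parameter
// defaults, otherwise it is an error. A plain stand-in with a hypothetical
// MiniParam instead of Parameter:
#include <cstddef>
#include <optional>
#include <stdexcept>
#include <vector>

struct MiniParam {
  std::optional<int> default_value;  // Stands in for Parameter::default_param().
};

inline std::vector<int> PadWithDefaults(std::vector<int> args, const std::vector<MiniParam> &params) {
  for (std::size_t i = args.size(); i < params.size(); ++i) {
    if (!params[i].default_value.has_value()) {
      throw std::invalid_argument("parameter has no default value");
    }
    args.push_back(*params[i].default_value);  // Fill the tail from defaults.
  }
  return args;
}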
1772
1773 void GraphExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *const arg_list) {
1774 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kGraphExecutorPy, runtime::ProfilerEvent::kInputProcess,
1775 phase);
1776 ProcessVmArgInner(args, GetResource(phase), arg_list);
1777 }
1778
1779 #ifdef ENABLE_DEBUGGER
1780 void GraphExecutorPy::TerminateDebugger() {
1781 if (Common::GetDebugTerminate()) {
1782 MS_LOG(INFO) << "Terminate debugger and clear resources!";
1783 ClearResAtexit();
1784 exit(static_cast<int>(!Common::GetDebugExitSuccess()));
1785 }
1786 }
1787 #endif
1788
1789 py::object GraphExecutorPy::Run(const py::tuple &args, const py::object &phase) {
1790 py::object res;
1791 HandleExceptionRethrow(
1792 [this, &res, &args, &phase]() {
1793 executor_running_ = true;
1794
1795 uint64_t start_time = 0;
1796 PROFILER_START(start_time);
1797 res = RunInner(args, phase);
1798 PROFILER_STAGE_END(start_time, runtime::ProfilerStage::kRunGraph);
1799
1800 executor_running_ = false;
1801 },
1802 [this]() { executor_running_ = false; }, [this]() { executor_running_ = false; },
1803 [this]() { executor_running_ = false; }, nullptr, true);
1804 return res;
1805 }
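
// --- Illustrative sketch (not part of the original file) ---
// The executor_running_ bookkeeping above (set on entry, cleared on every exit
// path, including the error callbacks) is the classic use case for a small
// RAII guard. Hypothetical helper, not used by the original code:
struct RunningFlagGuard {
  explicit RunningFlagGuard(bool *flag) : flag_(flag) { *flag_ = true; }
  ~RunningFlagGuard() { *flag_ = false; }  // Reset even if an exception unwinds.
  RunningFlagGuard(const RunningFlagGuard &) = delete;
  RunningFlagGuard &operator=(const RunningFlagGuard &) = delete;
  bool *flag_;
};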
1806
1807 #ifdef WITH_BACKEND
1808 void GraphExecutorPy::GeFirstInitParams() {
1809 static bool inited = false;
1810 if (!inited) {
1811 MS_LOG(INFO) << "Start init params.";
1812 const auto &init_params = GetParams(phase_);
1813 auto ret = InitParams(init_params, phase_);
1814 if (ret) {
1815 inited = true;
1816 }
1817 }
1818 }
1819 #endif
1820
1821 void GraphExecutorPy::ClearRunArgumentsResource(size_t input_arg_size, VectorRef *arg_list) {
1822 for (std::size_t i = 0; i < input_arg_size; ++i) {
1823 (*arg_list)[i] = nullptr;
1824 }
1825 }
1826
1827 py::object GraphExecutorPy::RunInner(const py::tuple &args, const py::object &phase_obj) {
1828 if (common::GetEnv(kSimulationLevel) == kSimulationLevelCompileGraph) {
1829 py::int_ ret = 0;
1830 return ret;
1831 }
1832 // Initialize for dynamic-obfuscated model inference.
1833 (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
1834 // The MindSpore debugger notifies the main thread to exit after one step and will not run the next step.
1835 #ifdef ENABLE_DEBUGGER
1836 TerminateDebugger();
1837 #endif
1838 if (!py::isinstance<py::str>(phase_obj)) {
1839 MS_LOG(INTERNAL_EXCEPTION) << "Run failed, phase input is not a str";
1840 }
1841 auto phase = py::cast<std::string>(phase_obj);
1842 auto phase_prefix = GetPhasePrefix(phase);
1843 PhaseManager::GetInstance().set_phase(phase_prefix);
1844 auto ms_context = MsContext::GetInstance();
1845 MS_EXCEPTION_IF_NULL(ms_context);
1846 static const bool enable_infer_boost = ms_context->IsEnableInferBoost();
1847 if (enable_infer_boost) {
1848 PhaseManager::GetInstance().set_phase(phase);
1849 }
1850 #ifdef WITH_BACKEND
1851 if (ms_context->backend_policy() == "ge") {
1852 if (!IsEnableRefMode()) {
1853 GeFirstInitParams();
1854 }
1855
1856 if (phase_prefix == "save") {
1857 auto pos = phase.find('.');
1858 std::string origin_phase = phase.substr(pos + 1);
1859 FuncGraphPtr func_graph = info_["train." + origin_phase]->func_graph;
1860 MS_EXCEPTION_IF_NULL(func_graph);
1861 MS_EXCEPTION_IF_NULL(MsContext::GetInstance());
1862 auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
1863 {MsContext::GetInstance()->get_param<std::string>(MS_CTX_DEVICE_TARGET),
1864 MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
1865 MS_EXCEPTION_IF_NULL(device_context);
1866 MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
1867 device_context->GetDeprecatedInterface()->DoExecNonInputGraph("save." + func_graph->ToString());
1868 ConfigManager::GetInstance().ResetConfig();
1869 return py::none();
1870 }
1871 }
1872 #endif
1873 auto ret_val = std::make_shared<py::object>();
1874 if (info_.count(phase) != 0 && info_[phase]->func_graph != nullptr) {
1875 if (IsGraphOutputValueNodeOrParameter(info_[phase]->func_graph->output(), args, ret_val)) {
1876 return *ret_val;
1877 }
1878 }
1879 #ifndef WITH_BACKEND
1880 if (ms_context->backend_policy() == "ge") {
1881 // Virtual output constructed for test cases.
1882 if (!args.empty()) {
1883 return args[0];
1884 }
1885 return args;
1886 }
1887 #endif
1888 auto iter = info_.find(phase);
1889 if (iter == info_.end()) {
1890 MS_LOG(INTERNAL_EXCEPTION) << "No executor info found for phase: " << phase;
1891 }
1892 auto &execute_info = iter->second;
1893 MS_EXCEPTION_IF_NULL(execute_info);
1894 if (args.size() > execute_info->arg_list_size) {
1895 MS_LOG(WARNING) << "The args size: " << args.size() << " exceeds the expected arg list size: " << execute_info->arg_list_size;
1896 }
1897 ProcessVmArg(args, phase, &execute_info->arg_list);
1898 // Start to run phase.
1899 compile::VmEvalFuncPtr run = GetVmEvalFunc(phase);
1900 if (run == nullptr) {
1901 MS_LOG(INTERNAL_EXCEPTION) << "Can't find run graph func for " << phase;
1902 }
1903
1904 MS_LOG(DEBUG) << "Eval run " << ms_context->backend_policy();
1905 const auto &output = execute_info->func_graph->output();
1906 MS_EXCEPTION_IF_NULL(output);
1907 const auto &output_abs = output->abstract();
1908 MS_EXCEPTION_IF_NULL(output_abs);
1909 BaseRef value = (*run)(execute_info->arg_list);
1910 bool need_recovery = distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() &&
1911 distributed::recovery::RecoveryContext::GetInstance()->need_reset();
1912 if (need_recovery) {
1913 // In the recovery scenario, the output value could be empty, so do not transform the return data.
1914 return py::none();
1915 }
1916 py::object res = BaseRefToPyDataWithUserData(value, output_abs);
1917 ClearRunArgumentsResource(args.size(), &execute_info->arg_list);
1918 PhaseManager::GetInstance().ClearPhase();
1919 MS_LOG(DEBUG) << "Run end";
1920 return res;
1921 }
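
// --- Illustrative sketch (not part of the original file) ---
// The phase-string convention RunInner depends on: a phase looks like
// "<prefix>.<rest>" (for example "save.net" or "train.net"), and the prefix
// selects special handling such as the "save" branch above. Hypothetical
// helper mirroring the find('.')/substr logic:
#include <string>
#include <utility>

inline std::pair<std::string, std::string> SplitPhase(const std::string &phase) {
  auto pos = phase.find('.');
  if (pos == std::string::npos) {
    return {phase, ""};  // No prefix present.
  }
  return {phase.substr(0, pos), phase.substr(pos + 1)};
}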
1922
1923 bool GraphExecutorPy::InitParams(const py::dict &init_params, const std::string &phase) const {
1924 MS_LOG(INFO) << "Init params for the ge backend, phase = " << phase;
1925 if (info_.count(phase) == 0) {
1926 MS_LOG(INTERNAL_EXCEPTION) << "No phase in executor: " << GetPhasePrefix(phase);
1927 }
1928 DeviceContext *device_context = nullptr;
1929 try {
1930 auto ms_context = MsContext::GetInstance();
1931 MS_EXCEPTION_IF_NULL(ms_context);
1932 auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
1933 device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id});
1934 } catch (const std::exception &) {
1935 return false;
1936 }
1937 MS_EXCEPTION_IF_NULL(device_context);
1938 MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
1939 return device_context->GetDeprecatedInterface()->RunInitGraph(info_.at(phase)->func_graph, init_params);
1940 }
1941
1942 FuncGraphPtr GraphExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase) const {
1943 MS_LOG(INFO) << "Start build df graph, phase = " << phase;
1944 if (info_.count(phase) == 0) {
1945 MS_LOG(INTERNAL_EXCEPTION) << "No phase in executor: " << GetPhasePrefix(phase);
1946 }
1947 DeviceContext *device_context = nullptr;
1948 try {
1949 auto ms_context = MsContext::GetInstance();
1950 MS_EXCEPTION_IF_NULL(ms_context);
1951 auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
1952 device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id});
1953 } catch (const std::exception &) {
1954 return nullptr;
1955 }
1956 MS_EXCEPTION_IF_NULL(device_context);
1957 MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
1958 return device_context->GetDeprecatedInterface()->BuildDFGraph(info_.at(phase)->func_graph, init_params);
1959 }
1960
1961 void GraphExecutorPy::UpdataParamNodeDefaultInput(
1962 const std::string &phase, const std::unordered_map<std::string, tensor::TensorPtr> ¶ms_value) {
1963 FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
1964 MS_EXCEPTION_IF_NULL(func_graph);
1965 MS_LOG(DEBUG) << "UpdataParamNodeDefaultInput for func graph(" << func_graph->ToString() << ") phase(" << phase
1966 << ")!";
1967 auto ¶ms = func_graph->parameters();
1968 for (const auto ¶m : params) {
1969 MS_EXCEPTION_IF_NULL(param);
1970 auto param_cast = param->cast_ptr<Parameter>();
1971 MS_EXCEPTION_IF_NULL(param_cast);
1972 auto iter = params_value.find(param_cast->name());
1973 if (iter != params_value.end()) {
1974 param_cast->set_default_param(iter->second);
1975 }
1976 }
1977 }
1978
1979 py::dict GraphExecutorPy::GetParams(const std::string &phase) {
1980 FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
1981 MS_EXCEPTION_IF_NULL(func_graph);
1982 py::dict parameter_dict;
1983 std::vector<AnfNodePtr> graph_params = func_graph->parameters();
1984 for (auto ¶m : graph_params) {
1985 MS_EXCEPTION_IF_NULL(param);
1986 auto param_ptr = std::static_pointer_cast<Parameter>(param);
1987 std::string name = param_ptr->name();
1988 auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param());
1989 if (tensor != nullptr) {
1990 parameter_dict[py::str(name)] = *tensor;
1991 }
1992 }
1993 return parameter_dict;
1994 }
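
// --- Illustrative sketch (not part of the original file) ---
// The export pattern in GetParams: build a py::dict keyed by parameter name and
// skip parameters without a concrete value. Reduced to hypothetical plain
// inputs; reuses the pybind11 include and `py` alias of this file.
#include <map>
#include <optional>
#include <string>

inline py::dict ExportParams(const std::map<std::string, std::optional<double>> &params) {
  py::dict out;
  for (const auto &[name, value] : params) {
    if (value.has_value()) {
      out[py::str(name)] = *value;  // Only initialized parameters are exported.
    }
  }
  return out;
}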
1995
1996 py::bytes GraphExecutorPy::GetRandomStatus(const std::string &phase) const {
1997 auto iter = info_.find(phase);
1998 if (iter == info_.end()) {
1999 MS_LOG(ERROR) << "Phase " << phase << " has not been compiled.";
2000 return "";
2001 }
2002 MS_EXCEPTION_IF_NULL(iter->second);
2003 MS_EXCEPTION_IF_NULL(iter->second->resource);
2004 auto &resource = iter->second->resource;
2005 auto backend = resource->GetBackend();
2006 const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
2007 MS_EXCEPTION_IF_NULL(mindrt_backend);
2008 auto actor_info = resource->GetResult(kActorInfo).cast<compile::ActorInfo>();
2009 auto random_status = mindrt_backend->GetRandomStatus(actor_info);
2010 return py::bytes(random_status.c_str(), random_status.size());
2011 }
2012
2013 void GraphExecutorPy::PyExePath(const py::object &py_exe_path) const {
2014 if (!py::isinstance<py::str>(py_exe_path)) {
2015 MS_LOG(INTERNAL_EXCEPTION) << "Failed, py_exe_path input is not a str";
2016 }
2017 auto py_exe_path_s = py::cast<std::string>(py_exe_path);
2018 auto ms_context = MsContext::GetInstance();
2019 ms_context->set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s);
2020 }
2021
2022 void GraphExecutorPy::KernelBuildServerDir(const py::object &kernel_build_server_dir) const {
2023 if (!py::isinstance<py::str>(kernel_build_server_dir)) {
2024 MS_LOG(INTERNAL_EXCEPTION) << "Failed, kernel_build_server_dir input is not a str";
2025 }
2026 auto kernel_build_server_dir_s = py::cast<std::string>(kernel_build_server_dir);
2027 auto ms_context = MsContext::GetInstance();
2028 ms_context->set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, kernel_build_server_dir_s);
2029 }
2030
2031 bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
2032 const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
2033 const std::vector<int64_t> &input_indexes, const std::string &, bool need_run) {
2034 auto ms_context = MsContext::GetInstance();
2035 MS_EXCEPTION_IF_NULL(ms_context);
2036 std::string name = ms_context->backend_policy();
2037 #ifdef WITH_BACKEND
2038 if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
2039 auto device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
2040 {kAscendDevice, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
2041 MS_EXCEPTION_IF_NULL(device_context);
2042 MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
2043 if (!device_context->GetDeprecatedInterface()->IsTsdOpened(ms_context)) {
2044 InitPipeline();
2045 }
2046 }
2047 #endif
2048
2049 if (name == kMsConvert || name == kMsVm || name == "ge") {
2050 #ifdef WITH_BACKEND
2051 if (iter_num == -1) {
2052 iter_num = INT32_MAX;
2053 }
2054 bool status = InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes, need_run);
2055 return status;
2056 #endif
2057 }
2058 return name == "ge";
2059 }
2060
2061 bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size,
2062 const std::vector<TypePtr> &types, const std::vector<std::vector<int64_t>> &shapes,
2063 const std::vector<int64_t> &input_indexes, bool need_run) {
2064 #if defined(__linux__) && defined(WITH_BACKEND)
2065 if (ps::PSContext::instance()->is_ps_mode() && ps::PSContext::instance()->cache_enable() &&
2066 !ps::PSContext::instance()->is_worker()) {
2067 return true;
2068 }
2069 #endif
2070 MS_LOG(INFO) << "Start InitDataSet Entry";
2071 mindspore::python_adapter::set_python_env_flag(true);
2072 ShapeVector int_input_indexes;
2073 (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes),
2074 [](int64_t item) { return static_cast<int64_t>(item); });
2075 std::vector<ShapeVector> int_shapes;
2076 (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes),
2077 [](const std::vector<int64_t> &item) {
2078 ShapeVector vector_item;
2079 (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item),
2080 [](int64_t inner_item) { return static_cast<int64_t>(inner_item); });
2081 return vector_item;
2082 });
2083 auto p_init = std::make_shared<Primitive>("InitDataSetQueue");
2084 p_init->set_attr("queue_name", MakeValue(queue_name));
2085 p_init->set_attr("size", MakeValue(static_cast<int64_t>(size)));
2086 p_init->set_attr("batch_size", MakeValue(static_cast<int64_t>(batch_size)));
2087 p_init->set_attr("types", MakeValue(types));
2088 p_init->set_attr("shapes", MakeValue(int_shapes));
2089 p_init->set_attr("input_indexes", MakeValue(int_input_indexes));
2090
2091 const std::vector<std::string> empty_str_list;
2092 p_init->set_attr("input_names", MakeValue(empty_str_list));
2093 p_init->set_attr("output_names", MakeValue(empty_str_list));
2094
2095 FuncGraphPtr func_graph = std::make_shared<FuncGraph>();
2096 auto app_init = std::make_shared<CNode>(AnfNodeWeakPtrList({NewValueNode(p_init)}), func_graph);
2097 func_graph->set_output(app_init);
2098 auto manager = MakeManager();
2099 manager->AddFuncGraph(func_graph);
2100
2101 // AbstractNone indicates there is no output for this apply node.
2102 auto abstract_none = std::make_shared<abstract::AbstractNone>();
2103 app_init->set_abstract(abstract_none);
2104 // Before compiling the graph, the iteration number needs to be reset.
2105 ConfigManager::GetInstance().ResetIterNum();
2106 #ifdef ENABLE_DUMP_IR
2107 mindspore::RDR::ResetRecorder();
2108 #endif
2109
2110 compile::SetMindRTEnable();
2111 auto backend = compile::CreateBackend();
2112 MS_EXCEPTION_IF_NULL(backend);
2113 auto context_ptr = MsContext::GetInstance();
2114 MS_EXCEPTION_IF_NULL(context_ptr);
2115 // Compile and run the dataset graph with MindRT.
2116 if (context_ptr->get_param<bool>(MS_CTX_ENABLE_MINDRT)) {
2117 #if defined(__linux__) && defined(WITH_BACKEND)
2118 if (ps::PSContext::instance()->is_worker() && ps::PSContext::instance()->cache_enable()) {
2119 distributed::DataQueueManager::GetInstance().CreateDataQueue(queue_name, size, 128);
2120 }
2121 #endif
2122
2123 const auto &mindrt_backend = std::dynamic_pointer_cast<compile::MindRTBackend>(backend);
2124 MS_EXCEPTION_IF_NULL(mindrt_backend);
2125 SetRunMode(func_graph, mindrt_backend.get());
2126 auto &actor_info = mindrt_backend->CompileGraphs(func_graph);
2127 VectorRef args;
2128 if (need_run) {
2129 VectorRef outputs;
2130 mindrt_backend->RunGraph(actor_info, args, &outputs);
2131 }
2132 ConfigManager::GetInstance().set_iter_num(queue_name, size);
2133 return true;
2134 }
2135
2136 auto convert_fn = backend->convert_fn();
2137 MS_EXCEPTION_IF_NULL(convert_fn);
2138 // Convert CNodeList to LinConvertResult.
2139 auto segment = std::make_shared<GraphSegment>(std::vector<AnfNodePtr>{app_init}, false);
2140 auto runner = convert_fn(segment, "");
2141 ConfigManager::GetInstance().set_iter_num(queue_name, size);
2142
2143 if (!(*runner.run)) {
2144 // The run function is empty.
2145 MS_LOG(EXCEPTION) << "Backend " << backend->name() << " does not support the tdt dataset.";
2146 }
2147
2148 // Launch the init dataset runner without inputs and outputs.
2149 VectorRef args;
2150 auto fn = runner.run;
2151 if (need_run) {
2152 (void)(*fn)(args);
2153 }
2154 MS_LOG(DEBUG) << "InitDataSetVm End.";
2155 return true;
2156 }
2157
2158 std::string GetJitLevel() {
2159 const auto &jit_config = PhaseManager::GetInstance().jit_config();
2160 auto iter = jit_config.find("jit_level");
2161 if (iter != jit_config.end()) {
2162 return iter->second;
2163 }
2164 return "";
2165 }
2166
2167 void ResetOpId() { mindspore::id_generator::reset_id(); }
2168 void ResetOpIdWithOffset() { mindspore::id_generator::reset_id_with_offset(); }
2169
2170 void InitHccl() {
2171 auto ms_context = MsContext::GetInstance();
2172 MS_EXCEPTION_IF_NULL(ms_context);
2173 ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
2174 #ifdef WITH_BACKEND
2175 auto backend = ms_context->backend_policy();
2176 if (backend == "ge") {
2177 if (!mindspore::distributed::Initialize()) {
2178 MS_LOG(EXCEPTION) << "InitHccl failed.";
2179 }
2180 InitPipeline();
2181 return;
2182 }
2183 #endif
2184 mindspore::python_adapter::set_python_env_flag(true);
2185 std::string device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
2186 if (ms_context->backend_policy() == "ms" && device_name == kAscendDevice) {
2187 if (!mindspore::distributed::Initialize()) {
2188 MS_LOG(EXCEPTION) << "InitHccl failed.";
2189 }
2190 }
2191 }
2192
2193 void FinalizeHccl() {
2194 auto ms_context = MsContext::GetInstance();
2195 MS_EXCEPTION_IF_NULL(ms_context);
2196 #ifdef WITH_BACKEND
2197 auto backend = ms_context->backend_policy();
2198 if (backend == "ge") {
2199 FinalizeBackend();
2200 return;
2201 }
2202 #endif
2203 session::ExecutorManager::Instance().Clear();
2204 device::KernelRuntimeManager::Instance().ClearRuntimeResource();
2205 device::DeviceContextManager::GetInstance().ClearDeviceContexts();
2206 device::DeviceContextManager::GetInstance().UnloadPlugin();
2207 }
2208
2209 uint32_t GetHcclRankId() {
2210 uint32_t rank_id = 0;
2211 bool ret = CommManager::GetInstance().GetRankID("", &rank_id);
2212 if (!ret) {
2213 MS_LOG(ERROR) << "Get rank id failed, return rank id " << rank_id << " as default.";
2214 }
2215 return rank_id;
2216 }
2217
2218 uint32_t GetHcclRankSize() {
2219 uint32_t rank_size = 0;
2220 bool ret = CommManager::GetInstance().GetRankSize("", &rank_size);
2221 if (!ret) {
2222 MS_LOG(ERROR) << "Get rank size failed, return rank size " << rank_size << " as default.";
2223 }
2224 return rank_size;
2225 }
2226
2227 void GraphExecutorPy::ExportGraph(const std::string &file_name, const std::string &phase, const py::object encrypt,
2228 char *key) {
2229 DeviceContext *device_context = nullptr;
2230 try {
2231 auto ms_context = MsContext::GetInstance();
2232 MS_EXCEPTION_IF_NULL(ms_context);
2233 auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
2234 device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id});
2235 } catch (const std::exception &) {
2236 MS_EXCEPTION(ValueError) << "Only support export file in 'AIR' format with Ascend backend.";
2237 }
2238 MS_EXCEPTION_IF_NULL(device_context);
2239 MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
2240 FuncGraphPtr func_graph = info_[phase]->func_graph;
2241 MS_EXCEPTION_IF_NULL(func_graph);
2242 device_context->GetDeprecatedInterface()->ExportDFGraph(file_name, func_graph->ToString(), encrypt, key);
2243 }
2244
2245 FuncGraphPtr LoadMindIR(const std::string &file_name, const char *dec_key, const size_t key_len,
2246 const std::string &dec_mode, const py::object decrypt, const bool obfuscated) {
2247 if (obfuscated) {
2248 MS_LOG(DEBUG) << "[LoadMindIR] Set customized function.";
2249 (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
2250 (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().init_calling_count();
2251 }
2252 FuncGraphPtr func_graph = nullptr;
2253 if (dec_mode == "Customized") {
2254 py::bytes key_bytes(dec_key);
2255 py::bytes model_stream = decrypt(file_name, key_bytes);
2256 std::string model_string(model_stream);
2257
2258 MindIRLoader mindir_loader;
2259 func_graph = mindir_loader.LoadMindIR(model_string.c_str(), model_string.size());
2260 } else {
2261 MindIRLoader mindir_loader(false, reinterpret_cast<const unsigned char *>(dec_key), key_len, dec_mode, false);
2262 func_graph = mindir_loader.LoadMindIR(file_name);
2263 }
2264 #ifdef ENABLE_DUMP_IR
2265 auto context = MsContext::GetInstance();
2266 MS_EXCEPTION_IF_NULL(context);
2267 if (context->CanDump(kIntroductory)) {
2268 DumpIR("load.ir", func_graph);
2269 }
2270 #endif
2271 return func_graph;
2272 }
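
// --- Illustrative sketch (not part of the original file) ---
// The "Customized" decryption hand-off in LoadMindIR: a Python callable is
// invoked through pybind11 and its bytes result is copied into an owned
// std::string buffer before the model is parsed from memory. Hypothetical
// helper name; reuses the pybind11 include and `py` alias of this file.
#include <string>

inline std::string DecryptToBuffer(const py::object &decrypt_fn, const std::string &file_name,
                                   const std::string &key) {
  py::bytes model_stream = decrypt_fn(file_name, py::bytes(key));  // Call back into Python.
  return std::string(model_stream);  // Own the decrypted bytes on the C++ side.
}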
2273
2274 FuncGraphPtr SplitMindIR(const std::string &file_name) {
2275 MS_LOG(INFO) << "Start split mindir";
2276 FuncGraphPtr func_graph = nullptr;
2277 MindIRLoader mindir_loader;
2278 func_graph = mindir_loader.LoadMindIR(file_name);
2279 if (func_graph == nullptr) {
2280 MS_LOG(ERROR) << "Failed to load the MindIR file. Please check the model file.";
2281 return nullptr;
2282 }
2283 #ifdef ENABLE_DUMP_IR
2284 auto context = MsContext::GetInstance();
2285 MS_EXCEPTION_IF_NULL(context);
2286 if (context->CanDump(kIntroductory)) {
2287 DumpIR("load.ir", func_graph);
2288 }
2289 #endif
2290 auto ms_context = MsContext::GetInstance();
2291 MS_EXCEPTION_IF_NULL(ms_context);
2292 auto parallel_context = parallel::ParallelContext::GetInstance();
2293 parallel_context->Reset();
2294 parallel_context->set_parallel_mode(parallel::kAutoParallel);
2295 parallel_context->set_strategy_search_mode(parallel::kRecursiveProgramming);
2296 parallel_context->set_direct_split(true);
2297 parallel_context->set_full_batch(true);
2298 parallel_context->set_group_ckpt_save_file("group_info");
2299
2300 FuncGraphManagerPtr func_graph_manager = func_graph->manager();
2301
2302 if (func_graph_manager == nullptr) {
2303 MS_LOG(INFO) << "func_graph_manager is null, create a new one.";
2304 std::vector<FuncGraphPtr> graphs{func_graph};
2305 func_graph_manager = std::make_shared<FuncGraphManager>(graphs);
2306 func_graph_manager->AddFuncGraph(func_graph);
2307 }
2308 pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
2309 resource->set_manager(func_graph_manager);
2310
2311 // Get the parameter items and add their values to args_abs.
2312 auto params = func_graph->parameters();
2313 auto inputs = func_graph->get_inputs();
2314 for (std::size_t i = 0; i < inputs.size(); i++) {
2315 auto input = inputs[i]->abstract();
2316 (void)parallel::ExtendInputArgsAbstractShape(input, i);
2317 }
2318 parallel::StepAutoParallel(func_graph, nullptr);
2319 parallel::StepParallel(func_graph, nullptr);
2320 parallel::StepAllreduceFusion(func_graph, nullptr);
2321 resource->set_func_graph(func_graph);
2322 resource->set_manager(func_graph->manager());
2323 opt::irpass::OptimizeIRPassLib irpass;
2324 opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
2325 opt::OptPassConfig virtual_output = opt::OptPassConfig({irpass.virtual_output_eliminate_});
2326
2327 opt::OptPassGroupMap map_parallel_eliminate(
2328 {{"virtual_dataset", virtual_dataset}, {"virtual_output", virtual_output}});
2329
2330 auto split_pass_opts = opt::Optimizer::MakeOptimizer("map_parallel_eliminate", resource, map_parallel_eliminate);
2331 ProfileExecute(MsProfile::GetProfile()->Step("split_pass_opts"),
2332 [&split_pass_opts, &func_graph]() { func_graph = split_pass_opts->step(func_graph, true); });
2333
2334 AbstractBasePtrList args_abs_list;
2335 (void)std::transform(params.begin(), params.end(), std::back_inserter(args_abs_list),
2336 [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
2337 func_graph = pipeline::Renormalize(resource, func_graph, args_abs_list);
2338
2339 resource->set_args_abs(args_abs_list);
2340
2341 MindIRExporter mindir_exporter;
2342 mindir_exporter.ExportProto(func_graph, "split_net", nullptr);
2343
2344 parallel::HandleGroupInfo();
2345
2346 return func_graph;
2347 }
2348
2349 FuncGraphPtr SplitDynamicMindIR(const std::string &file_name, size_t device_num, size_t rank_id, bool sapp) {
2350 MS_LOG(INFO) << "Start split dynamic mindir for transformer network";
2351 FuncGraphPtr func_graph = nullptr;
2352 MindIRLoader mindir_loader;
2353 func_graph = mindir_loader.LoadMindIR(file_name);
2354 if (func_graph == nullptr) {
2355 MS_LOG(ERROR) << "Failed to load the MindIR file. Please check the model file.";
2356 return nullptr;
2357 }
2358 #ifdef ENABLE_DUMP_IR
2359 auto context = MsContext::GetInstance();
2360 MS_EXCEPTION_IF_NULL(context);
2361 if (context->CanDump(kIntroductory)) {
2362 DumpIR("load.ir", func_graph);
2363 }
2364 #endif
2365 auto ms_context = MsContext::GetInstance();
2366 MS_EXCEPTION_IF_NULL(ms_context);
2367 auto parallel_context = parallel::ParallelContext::GetInstance();
2368 parallel_context->Reset();
2369 parallel_context->set_parallel_mode(parallel::kAutoParallel);
2370 parallel_context->set_strategy_search_mode(parallel::kRecursiveProgramming);
2371 parallel_context->set_direct_split(true);
2372 parallel_context->set_full_batch(true);
2373 parallel_context->set_group_ckpt_save_file("group_info");
2374
2375 for (size_t rank_id_iter = 0; rank_id_iter < device_num; rank_id_iter++) {
2376 auto tmp_func_graph = mindspore::BasicClone(func_graph);
2377 FuncGraphManagerPtr func_graph_manager = tmp_func_graph->manager();
2378
2379 if (func_graph_manager == nullptr) {
2380 MS_LOG(INFO) << "func_graph_manager is null";
2381 std::vector<FuncGraphPtr> graphs{tmp_func_graph};
2382 func_graph_manager = std::make_shared<FuncGraphManager>(graphs);
2383 func_graph_manager->AddFuncGraph(tmp_func_graph);
2384 }
2385
2386 auto inputs = tmp_func_graph->get_inputs();
2387 for (std::size_t i = 0; i < inputs.size(); i++) {
2388 auto input = inputs[i]->abstract();
2389 (void)parallel::ExtendInputArgsAbstractShape(input, i);
2390 }
2391
2392 auto res = parallel::StepAssignedParallel(tmp_func_graph, func_graph_manager, device_num, rank_id_iter, sapp);
2393 if (!res) {
2394 MS_LOG(ERROR) << "StepAssignedParallel failed. Please check.";
2395 return nullptr;
2396 }
2397 pipeline::ResourcePtr resource = std::make_shared<pipeline::Resource>();
2398 resource->set_is_load(false);
2399 resource->set_manager(func_graph_manager);
2400 resource->set_func_graph(tmp_func_graph);
2401 // Get the parameter items and add their values to args_abs.
2402 auto params = tmp_func_graph->parameters();
2403 AbstractBasePtrList args_abs_list;
2404 (void)std::transform(params.begin(), params.end(), std::back_inserter(args_abs_list),
2405 [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
2406 tmp_func_graph = pipeline::Renormalize(resource, tmp_func_graph, args_abs_list);
2407
2408 #ifdef ENABLE_DUMP_IR
2409 auto re_context = MsContext::GetInstance();
2410 MS_EXCEPTION_IF_NULL(re_context);
2411 if (re_context->CanDump(kIntroductory)) {
2412 string renormalize_net_name = "Renormalize_" + std::to_string(rank_id_iter) + ".ir";
2413 DumpIR(renormalize_net_name, tmp_func_graph);
2414 }
2415 #endif
2416
2417 parallel::HandleGroupInfo();
2418 string net_save_name = "split_net" + std::to_string(rank_id_iter);
2419 MindIRExporter mindir_exporter;
2420 res = mindir_exporter.ExportProto(tmp_func_graph, net_save_name, nullptr);
2421 if (!res) {
2422 MS_LOG(ERROR) << "Export MindIR file failed. Please check.";
2423 return nullptr;
2424 }
2425 }
2426
2427 return func_graph;
2428 }
2429
2430 FuncGraphPtr DynamicObfuscateMindIR(const std::string &file_name, float obf_ratio, int branch_control_input,
2431 char *dec_key, const size_t key_len, const std::string &dec_mode) {
2432 if (branch_control_input == 0) {
2433 (void)mindspore::kernel::CustomizedOpaquePredicate::GetInstance().set_func_names();
2434 MS_LOG(DEBUG) << "[DynamicObfuscateMindIR] set function names finished.";
2435 }
2436 mindspore::DynamicObfuscator dynamic_obfuscator(obf_ratio, branch_control_input);
2437 MindIRLoader mindir_loader(false, reinterpret_cast<unsigned char *>(dec_key), key_len, dec_mode, false);
2438 FuncGraphPtr func_graph = mindir_loader.LoadMindIR(file_name);
2439 // Check the load result before the graph is used.
2440 if (func_graph == nullptr) {
2441 MS_LOG(EXCEPTION) << "[DynamicObfuscateMindIR] load mindir failed, please check the mindir file.";
2442 }
2443 ModifyGraphs(func_graph);
2444 auto manager = func_graph->manager();
2445 if (manager == nullptr) {
2446 manager = MakeManager();
2447 manager->AddFuncGraph(func_graph, true);
2448 }
2449 InferFuncGraphLoaded(func_graph);
2450 mindspore::FuncGraphPtr obfuscated_graph = dynamic_obfuscator.ObfuscateMindIR(func_graph);
2451 if (obfuscated_graph == nullptr) {
2452 MS_LOG(ERROR) << "[DynamicObfuscateMindIR] obfuscate model failed.";
2453 return nullptr;
2454 }
2455 return obfuscated_graph;
2456 }
2457
2458 void CloseTsd(bool force) {
2459 #ifdef WITH_BACKEND
2460 auto context_ptr = MsContext::GetInstance();
2461 MS_EXCEPTION_IF_NULL(context_ptr);
2462 if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice) {
2463 const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
2464 {kAscendDevice, context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
2465 MS_EXCEPTION_IF_NULL(device_context);
2466 MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
2467 (void)device_context->GetDeprecatedInterface()->CloseTsd(context_ptr, force);
2468 }
2469 #endif
2470 }
2471
2472 void InitPipeline() {
2473 RecordInitStatus();
2474 // Set the python env flag.
2475 mindspore::python_adapter::set_python_env_flag(true);
2476 auto ms_context = MsContext::GetInstance();
2477 MS_EXCEPTION_IF_NULL(ms_context);
2478 CompileConfigManager::GetInstance().CollectCompileConfig();
2479 #ifdef WITH_BACKEND
2480 auto backend = ms_context->backend_policy();
2481 auto device_name = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
2482 if (backend == "ge") {
2483 const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
2484 {device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
2485 MS_EXCEPTION_IF_NULL(device_context);
2486 device_context->Initialize();
2487 }
2488 if (!common::UseDynamicCluster()) {
2489 if (device_name == kAscendDevice) {
2490 const auto &device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
2491 {device_name, ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID)});
2492 MS_EXCEPTION_IF_NULL(device_context);
2493 MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
2494 if (!device_context->GetDeprecatedInterface()->OpenTsd(ms_context)) {
2495 MS_LOG(EXCEPTION) << "Open tsd failed";
2496 }
2497 }
2498 }
2499 #endif
2500 }
2501
2502 void FinalizeBackend() { CloseTsd(); }
2503
2504 void MemoryRecycle() {
2505 #ifdef ENABLE_DUMP_IR
2506 mindspore::RDR::ResetRecorder();
2507 #endif
2508 ReclaimOptimizer();
2509 session::ExecutorManager::Instance().ClearDoneTasks();
2510 ad::g_k_prims.clear();
2511 ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
2512 abstract::AnalysisResultCacheMgr::GetInstance().Clear();
2513 abstract::AnalysisContext::ClearContext();
2514 kArgsCache.clear();
2515 kCellArgsMap.clear();
2516 // Clean up static variables to prevent a crash, as static variables are released after
2517 // the Python threads have been released.
2518 parse::data_converter::ClearObjectCache();
2519 parse::Parser::CleanParserResource();
2520 trace::ClearTraceStack();
2521 pynative::PyNativeExecutor::GetInstance()->ClearRes();
2522 ConfigManager::GetInstance().ResetConfig();
2523 ScopeManager::GetInstance().ClearScope();
2524 FuncGraphLoopBreaker::Inst().CleanMetaFuncGraphs();
2525 FuncGraphLoopBreaker::Inst().BreakLoop();
2526 }
2527
2528 void BindDeviceCtx() { device::DeviceContextManager::GetInstance().BindDeviceCtx(); }
2529
2530 void ClearResPart1() {
2531 pynative::PyNativeExecutor::GetInstance()->WorkerJoin();
2532 runtime::OpExecutor::GetInstance().WorkerJoin();
2533 // When the python process exits, the kernels on the device may not have finished executing.
2534 device::KernelRuntimeManager::Instance().WaitTaskFinishOnDevice();
2535 device::DeviceContextManager::GetInstance().WaitTaskFinishOnDevice();
2536
2537 RecordExitStatus();
2538 #ifdef ENABLE_DUMP_IR
2539 mindspore::RDR::Snapshot();
2540 mindspore::RDR::ResetRecorder();
2541 #endif
2542 runtime::GraphScheduler::GetInstance().Clear();
2543 runtime::ProfilerAnalyzer::GetInstance().Clear();
2544
2545 auto ms_context = MsContext::GetInstance();
2546 MS_EXCEPTION_IF_NULL(ms_context);
2547 if (ms_context->backend_policy() != "ge") {
2548 // Clear runtime resources before destroying the HCCL comm.
2549 MS_LOG(INFO) << "Start clear kernel runtime...";
2550 device::KernelRuntimeManager::Instance().ClearRuntimeResource();
2551 MS_LOG(INFO) << "End clear kernel runtime.";
2552 }
2553
2554 MS_LOG(INFO) << "Start Finalize StreamSynchronizer...";
2555 device::StreamSynchronizer::GetInstance()->Finalize();
2556 MS_LOG(INFO) << "End Finalize StreamSynchronizer...";
2557
2558 PrimitivePy::ClearHookRes();
2559 ad::g_k_prims.clear();
2560 ad::PrimBpropOptimizer::GetPrimBpropOptimizerInst().Clear();
2561
2562 abstract::ClearPrimEvaluatorMap();
2563 pipeline::GetMethodMap().clear();
2564 pipeline::GetAttrMap().clear();
2565 pipeline::GraphExecutorPy::ClearRes();
2566 pipeline::ReclaimOptimizer();
2567 }
2568
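// Second stage of exit cleanup: clears the PyNative executor, tears down the GE graph
// wrappers (or resets the iteration counter otherwise), finalizes collective
// communication, and releases the device contexts and static-analysis caches.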
void ClearResPart2() {
  MS_LOG(INFO) << "Start clear PyNativeExecutor...";
  pynative::PyNativeExecutor::GetInstance()->ClearRes();
  MS_LOG(INFO) << "End clear PyNativeExecutor.";

#ifdef WITH_BACKEND
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->backend_policy() == "ge") {
    auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
    DeviceContext *device_context =
      device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({kAscendDevice, device_id});
    MS_EXCEPTION_IF_NULL(device_context);
    MS_EXCEPTION_IF_NULL(device_context->GetDeprecatedInterface());
    device_context->GetDeprecatedInterface()->ClearGraphWrapper();
    device_context->GetDeprecatedInterface()->ClearOpAdapterMap();
    // Unregister the external allocator before clearing the stream and the graph runner.
    device_context->GetDeprecatedInterface()->UnregisterExternalAllocator();
    // For GE, clear the runtime resources only after the graphs have been cleared.
    MS_LOG(INFO) << "Start clear kernel runtime...";
    device::KernelRuntimeManager::Instance().ClearRuntimeResource();
    MS_LOG(INFO) << "End clear kernel runtime.";
  } else {
    MS_LOG(INFO) << "Start clear ConfigManager...";
    ConfigManager::GetInstance().ResetIterNum();
    MS_LOG(INFO) << "End clear ConfigManager.";
  }
#else
  MS_LOG(INFO) << "Start clear ConfigManager...";
  ConfigManager::GetInstance().ResetIterNum();
  MS_LOG(INFO) << "End clear ConfigManager.";
#endif

  session::ExecutorManager::Instance().Clear();
  // For GE, HcclCommDestroy must happen after RemoveGraph in ClearGraphWrapper.
  (void)distributed::collective::CollectiveManager::instance()->Finalize();

  MS_LOG(INFO) << "Start clear device context...";
  device::DeviceContextManager::GetInstance().ClearDeviceContexts();
  MS_LOG(INFO) << "End clear device context.";

  MS_LOG(INFO) << "Start clear AnalysisResultCacheMgr...";
  abstract::AnalysisResultCacheMgr::GetInstance().Clear();
  MS_LOG(INFO) << "End clear AnalysisResultCacheMgr.";

  MS_LOG(INFO) << "Start clear AnalysisContext...";
  abstract::AnalysisContext::ClearContext();
  MS_LOG(INFO) << "End clear AnalysisContext.";

  MS_LOG(INFO) << "Start clear AnalysisSchedule...";
  abstract::AnalysisSchedule::GetInstance().Stop();
  MS_LOG(INFO) << "End clear AnalysisSchedule.";
#ifdef ENABLE_DEBUGGER
  auto debugger = Debugger::GetInstance();
  MS_EXCEPTION_IF_NULL(debugger);
  debugger->Reset();
#endif
  kArgsCache.clear();
  kCellArgsMap.clear();
}

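// Third stage of exit cleanup: drops parser, trace and cost-graph state, shuts down the
// protobuf library and resets the Python scope once all py::object handles are freed.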
void ClearResPart3() {
  // Clean up static variables to prevent a crash, since static variables are destroyed
  // after the Python threads have been released.
  MS_LOG(INFO) << "Start clear object cache...";
  parse::data_converter::ClearObjectCache();
  MS_LOG(INFO) << "End clear object cache.";

  MS_LOG(INFO) << "Start clear Parser...";
  parse::Parser::CleanParserResource();
  MS_LOG(INFO) << "End clear Parser.";

  MS_LOG(INFO) << "Start ClearTraceStack...";
  trace::ClearTraceStack();
  MS_LOG(INFO) << "End ClearTraceStack.";

  MS_LOG(INFO) << "Start clear InterpretNodeRecorder...";
  InterpretNodeRecorder::GetInstance().Clear();
  MS_LOG(INFO) << "End clear InterpretNodeRecorder.";

  MS_LOG(INFO) << "Start clear parallel::entire_costgraph...";
  parallel::entire_costgraph.reset();
  MS_LOG(INFO) << "End clear parallel::entire_costgraph.";

  MS_LOG(INFO) << "Start clear ProtobufLibrary...";
  google::protobuf::ShutdownProtobufLibrary();
  MS_LOG(INFO) << "End clear ProtobufLibrary.";
  // Reset the Python scope only after all py::object handles have been freed.
  MS_LOG(INFO) << "Start clear python_adapter...";
  python_adapter::ResetPythonScope();
  MS_LOG(INFO) << "End clear python_adapter.";
}

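// Clears process-wide singletons: the profiler, graph-kernel and somas managers, data
// queues, session and kernel-runtime factories, dump writers, the communication manager
// and the expander cache.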
void ClearSingleton() {
  MS_LOG(INFO) << "Start clear singleton...";
  profiler::Profiler::Clear();
#ifdef ENABLE_AKG
  kernel::GraphKernelBuildManager::Instance().Clear();
#endif
  somas::SomasManager::Instance().Clear();
  GraphKernelInfoManager::Instance().Clear();
  device::DataQueueMgr::GetInstance().Clear();
  session::SessionFactory::Get().Clear();
  device::KernelRuntimeManager::Instance().Clear();
  OpPrimPyRegister::GetInstance().Clear();
#ifndef ENABLE_SECURITY
  DumpJsonParser::Finalize();
  AclDumpJsonWriter::Finalize();
#endif
  CommManager::Clear();
  expander::ClearAllCache();
  MS_LOG(INFO) << "End clear singleton.";
}

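// Releases every pipeline resource at process exit: drains any pending exception, runs
// the three cleanup stages in order, clears the singletons and finally unloads the
// device plugins.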
void ClearResAtexit() {
  MS_LOG(INFO) << "Pipeline clears all resources";
  try {
    MsException::Instance().CheckException();
  } catch (const std::exception &e) {
    MS_LOG(ERROR) << "Check exception before process exit: " << e.what();
  }
  ClearResPart1();
  ClearResPart2();

  mindspore::trans::FormatHelper::GetInstance().Clear();
  ClearResPart3();
  ClearSingleton();
  MS_LOG(INFO) << "Start unload dynamic lib...";
  device::DeviceContextManager::GetInstance().UnloadPlugin();
  MS_LOG(INFO) << "End unload dynamic lib.";
}

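// Encrypts plain_len bytes of plain_data with the given key and encryption mode and
// returns the cipher data as Python bytes; raises ValueError on failure.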
py::bytes PyEncrypt(char *plain_data, size_t plain_len, char *key, size_t key_len, const std::string &enc_mode) {
  size_t encrypt_len = 0;
  auto encrypt_data = mindspore::Encrypt(&encrypt_len, reinterpret_cast<Byte *>(plain_data), plain_len,
                                         reinterpret_cast<Byte *>(key), key_len, enc_mode);
  if (encrypt_data == nullptr) {
    MS_EXCEPTION(ValueError) << "Encrypt failed";
  }
  auto py_encrypt_data = py::bytes(reinterpret_cast<char *>(encrypt_data.get()), encrypt_len);
  return py_encrypt_data;
}

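// Decrypts the file at encrypt_data_path with the given key and decryption mode and
// returns the plain data as Python bytes, or py::none() on failure.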
py::bytes PyDecrypt(const std::string &encrypt_data_path, char *key, size_t key_len, const std::string &dec_mode) {
  size_t decrypt_len = 0;
  auto decrypt_data =
    mindspore::Decrypt(&decrypt_len, encrypt_data_path, reinterpret_cast<Byte *>(key), key_len, dec_mode);
  if (decrypt_data == nullptr) {
    MS_LOG(ERROR) << "Decrypt failed";
    return py::none();
  }
  auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
  return py_decrypt_data;
}

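// Decrypts data_size bytes of the in-memory buffer model_data with the given key and
// decryption mode and returns the plain data as Python bytes, or py::none() on failure.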
py::bytes PyDecryptData(char *model_data, size_t data_size, char *key, size_t key_len, const std::string &dec_mode) {
  size_t decrypt_len = 0;
  auto decrypt_data = mindspore::Decrypt(&decrypt_len, reinterpret_cast<Byte *>(model_data), data_size,
                                         reinterpret_cast<Byte *>(key), key_len, dec_mode);
  if (decrypt_data == nullptr) {
    MS_LOG(ERROR) << "Decrypt failed";
    return py::none();
  }
  auto py_decrypt_data = py::bytes(reinterpret_cast<char *>(decrypt_data.get()), decrypt_len);
  return py_decrypt_data;
}

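// Returns true when the file at file_path is an encrypted (cipher) file.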
bool PyIsCipherFile(const std::string &file_path) { return mindspore::IsCipherFile(file_path); }

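// Finalizes the distributed cluster instance, but only when the cluster was initialized
// and this process is exiting without an exception.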
void FinalizeCluster() {
#if defined(__linux__) && defined(WITH_BACKEND)
  if (distributed::cluster::ClusterContext::instance()->initialized()) {
    if (!distributed::cluster_exit_with_exception()) {
      MS_LOG(INFO) << "Start finalize the cluster instance.";
      // Finalize MindSpore cluster only when this process exits without any exception.
      (void)distributed::cluster::ClusterContext::instance()->Finalize(UINT32_MAX);
      MS_LOG(INFO) << "End finalize the cluster instance.";
    }
  }
#endif
}

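// Copies cache blocks between host and device memory according to block_mapping. Each
// row of the [n, 2] mapping tensor is a (source block, destination block) pair, and
// is_device_to_host selects the copy direction.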
void SwapCache(const tensor::TensorPtr &host, const tensor::TensorPtr &device, const tensor::TensorPtr &block_mapping,
               const bool &is_device_to_host) {
  auto block_mapping_shape = block_mapping->shape();
  if (block_mapping_shape.size() != 2) {
    MS_LOG_EXCEPTION << "The shape size of Cache input mapping tensor should be 2, but got: "
                     << block_mapping_shape.size();
  }
  if (block_mapping_shape[1] != 2) {
    MS_LOG_EXCEPTION << "The second dim of CacheKernel input mapping tensor should be 2, but got: "
                     << block_mapping_shape[1];
  }

  // One block spans all dimensions after the leading block dimension.
  auto in_shape = device->shape();
  auto type_byte = GetTypeByte(TypeIdToType(host->data_type()));
  size_t block_size_in_bytes = LongToSize(
    std::accumulate(in_shape.begin() + 1, in_shape.end(), SizeToLong(type_byte), std::multiplies<int64_t>()));

  uint8_t *host_ptr = reinterpret_cast<uint8_t *>(host->data_c());
  MS_EXCEPTION_IF_NULL(host_ptr);
  auto device_addr = std::dynamic_pointer_cast<device::DeviceAddress>(device->device_address());
  MS_EXCEPTION_IF_NULL(device_addr);
  uint8_t *device_ptr = reinterpret_cast<uint8_t *>(const_cast<void *>(device_addr->GetPtr()));
  MS_EXCEPTION_IF_NULL(device_ptr);

  auto block_mapping_data = reinterpret_cast<int64_t *>(block_mapping->data_c());
  for (int64_t i = 0; i < block_mapping_shape[0]; i++) {
    // Row i of the mapping holds the (source, destination) block indices.
    int64_t src_block_num = block_mapping_data[2 * i];
    int64_t dst_block_num = block_mapping_data[2 * i + 1];
    size_t src_block_offset = LongToSize(src_block_num) * block_size_in_bytes;
    size_t dst_block_offset = LongToSize(dst_block_num) * block_size_in_bytes;

    if (is_device_to_host) {
      device_addr->CopyDeviceToHost(host_ptr + dst_block_offset, device_ptr + src_block_offset, block_size_in_bytes);
    } else {
      device_addr->CopyHostToDevice(device_ptr + dst_block_offset, host_ptr + src_block_offset, block_size_in_bytes);
    }
  }
}
}  // namespace pipeline
}  // namespace mindspore