1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <pybind11/operators.h>
18 #include <stack>
19 #include "kernel/oplib/oplib.h"
20 #include "pipeline/jit/ps/pipeline.h"
21 #include "frontend/operator/composite/composite.h"
22 #include "pipeline/pynative/pynative_execute.h"
23 #include "utils/symbolic.h"
24 #include "include/common/pybind_api/api_register.h"
25 #include "include/common/utils/python_adapter.h"
26 #ifndef ENABLE_SECURITY
27 #include "include/common/utils/summary/event_writer.h"
28 #endif
29 #include "include/common/utils/config_manager.h"
30 #include "include/common/utils/mpi/mpi_config.h"
31 #include "utils/ms_utils.h"
32 #include "utils/ms_context.h"
33 #include "include/common/utils/parallel_context.h"
34 #include "include/common/utils/offload_context.h"
35 #include "frontend/parallel/costmodel_context.h"
36 #if ((defined ENABLE_CPU) && (!defined _WIN32))
37 #include "include/backend/distributed/ps/util.h"
38 #endif
39 #include "include/backend/distributed/ps/ps_context.h"
40 #include "include/backend/distributed/init.h"
41 #include "include/backend/distributed/recovery/recovery_context.h"
42 #include "include/backend/distributed/collective/collective_manager.h"
43 #if defined(__linux__) && defined(WITH_BACKEND)
44 #include "runtime/graph_scheduler/embedding_cache_scheduler.h"
45 #endif
46 #include "runtime/hardware/device_context_manager.h"
47 #include "include/backend/mem_reuse/mem_dynamic_allocator.h"
48 #include "frontend/parallel/tensor_layout/tensor_transform.h"
49
50 #include "pybind_api/gil_scoped_long_running.h"
51
52 #ifndef ENABLE_SECURITY
53 #include "include/backend/debug/profiler/profiling.h"
54 #endif
55 #include "include/common/profiler.h"
56
57 #include "pipeline/jit/pi/external.h"
58 #include "include/common/np_dtype/np_dtypes.h"
59
60 namespace py = pybind11;
61 using GraphExecutorPy = mindspore::pipeline::GraphExecutorPy;
62 using Pipeline = mindspore::pipeline::Pipeline;
63 using PrimitivePy = mindspore::PrimitivePy;
64 using MetaFuncGraph = mindspore::MetaFuncGraph;
65 #ifndef ENABLE_SECURITY
66 using EventWriter = mindspore::summary::EventWriter;
67 #endif // ENABLE_SECURITY
68 using OpLib = mindspore::kernel::OpLib;
69 using ParallelContext = mindspore::parallel::ParallelContext;
70 using CostModelContext = mindspore::parallel::CostModelContext;
71 using TensorTransform = mindspore::parallel::TensorTransform;
72 using OffloadContext = mindspore::OffloadContext;
73 using mindspore::MsCtxParam;
74 using PSContext = mindspore::ps::PSContext;
75 using CollectiveManager = mindspore::distributed::collective::CollectiveManager;
76 using RecoveryContext = mindspore::distributed::recovery::RecoveryContext;
77 using DeviceContextManager = mindspore::device::DeviceContextManager;
78 using DeviceContext = mindspore::device::DeviceContext;
79
// Default value for the `start_end` argument of `_collect_host_info` (see RegHostProfile):
// 2 means the record does not distinguish between a start and an end flag.
constexpr int PROFILER_RECORD_STAMP{2};
81
82 #ifndef ENABLE_SECURITY
83 namespace mindspore {
84 namespace profiler {
RegProfiler(const py::module * m)85 void RegProfiler(const py::module *m) {
86 (void)py::class_<Profiler, std::shared_ptr<Profiler>>(*m, "Profiler")
87 .def_static("get_instance", &Profiler::GetInstance, py::arg("device_name"), "Profiler get_instance.")
88 .def("init", &Profiler::Init, py::arg("profiling_path"), py::arg("device_id") = py::int_(0),
89 py::arg("profiling_options") = py::str(""), "init")
90 .def("start", &Profiler::Start, "start")
91 .def("stop", &Profiler::Stop, "stop")
92 .def("finalize", &Profiler::Finalize, "finalize")
93 .def("sync_enable", &Profiler::SyncEnable, py::arg("enable_flag"))
94 .def("data_process_enable", &Profiler::DataProcessEnable, py::arg("enable_flag"))
95 .def("step_profiling_enable", &Profiler::StepProfilingEnable, py::arg("enable_flag"),
96 "enable or disable step profiling")
97 .def("enable_op_time", &Profiler::EnableOpTime, "Enable op_time.")
98 .def("enable_profile_memory", &Profiler::EnableProfileMemory, "Enable profile_memory.");
99 }
RegProfilerManager(const py::module * m)100 void RegProfilerManager(const py::module *m) {
101 (void)py::class_<ProfilerManager, std::shared_ptr<ProfilerManager>>(*m, "ProfilerManager")
102 .def_static("get_instance", &ProfilerManager::GetInstance, "ProfilerManager get_instance.")
103 .def("dynamic_status", &ProfilerManager::GetNetDynamicShapeStatus, "dynamic_status")
104 .def("set_profile_framework", &ProfilerManager::SetProfileFramework, py::arg("profile_framework"));
105 }
106
107 // level: 0, for developer user, 1, for general user;
108 // profile_framework: 0, all host info, 1, host memory, 2, host time;
109 // start_end: 0, start flag, 1, end flag, 2, no distinguish start and end.
110 // Default parameter for host profile meaning: for developer user, collect both time and memory, record timestamp.
RegHostProfile(py::module * m)111 void RegHostProfile(py::module *m) {
112 m->def("_collect_host_info", &CollectHostInfo, py::arg("module_name"), py::arg("event"), py::arg("stage"),
113 py::arg("level") = py::int_(0), py::arg("profile_framework") = py::int_(0),
114 py::arg("start_end") = py::int_(PROFILER_RECORD_STAMP), py::arg("custom_info") = py::dict())
115 .def("get_clock_time", &GetClockTime)
116 .def("get_clock_syscnt", &GetClockSyscnt);
117 }
118
RegFrameworkProfiler(py::module * m)119 void RegFrameworkProfiler(py::module *m) {
120 m->def(
121 "_framework_profiler_step_start", []() { runtime::ProfilerAnalyzer::GetInstance().StartStep(); },
122 "Profiler step start")
123 .def(
124 "_framework_profiler_step_end", []() { runtime::ProfilerAnalyzer::GetInstance().EndStep(); }, "Profiler step end")
125 .def(
126 "_framework_profiler_clear", []() { runtime::ProfilerAnalyzer::GetInstance().Clear(); },
127 "Dump json and clear data")
128 .def("_framework_profiler_enable_mi", []() { runtime::ProfilerAnalyzer::GetInstance().EnableMiProfile(); });
129 }
130
RegFrameworkPythonProfileRecorder(py::module * m)131 void RegFrameworkPythonProfileRecorder(py::module *m) {
132 (void)py::class_<runtime::PythonProfilerRecorder, std::shared_ptr<runtime::PythonProfilerRecorder>>(
133 *m, "PythonProfilerRecorder")
134 .def(py::init<const std::string &>())
135 .def("record_start", &runtime::PythonProfilerRecorder::record_start, "record_start")
136 .def("record_end", &runtime::PythonProfilerRecorder::record_end, "record_end");
137 }
138
139 } // namespace profiler
140 } // namespace mindspore
141 #endif // ENABLE_SECURITY
142
143 namespace mindspore {
RegModule(py::module * m)144 void RegModule(py::module *m) {
145 RegTyping(m);
146 RegCNode(m);
147 RegCell(m);
148 RegMetaFuncGraph(m);
149 RegFuncGraph(m);
150 RegUpdateFuncGraphHyperParams(m);
151 RegParamInfo(m);
152 RegPrimitive(m);
153 RegPrimitiveFunction(m);
154 RegSignatureEnumRW(m);
155 RegRandomSeededGenerator(m);
156 mindspore::tensor::RegMetaTensor(m);
157 mindspore::tensor::RegCSRTensor(m);
158 mindspore::tensor::RegCOOTensor(m);
159 mindspore::tensor::RegRowTensor(m);
160 mindspore::tensor::RegMapTensor(m);
161 RegValues(m);
162 mindspore::initializer::RegRandomNormal(m);
163 RegMsContext(m);
164 RegSecurity(m);
165 RegForkUtils(m);
166 RegNumpyTypes(m);
167 mindspore::hal::RegStream(m);
168 mindspore::hal::RegEvent(m);
169 mindspore::hal::RegMemory(m);
170 mindspore::pynative::RegPyNativeExecutor(m);
171 mindspore::pynative::RegisterPyBoostFunction(m);
172 mindspore::pijit::RegPIJitInterface(m);
173 mindspore::prim::RegCompositeOpsGroup(m);
174 #ifndef ENABLE_SECURITY
175 mindspore::profiler::RegProfilerManager(m);
176 mindspore::profiler::RegProfiler(m);
177 mindspore::profiler::RegHostProfile(m);
178 mindspore::profiler::RegFrameworkProfiler(m);
179 mindspore::profiler::RegFrameworkPythonProfileRecorder(m);
180 #endif
181 #ifdef _MSC_VER
182 mindspore::abstract::RegPrimitiveFrontEval();
183 #endif
184 mindspore::ops::RegOpEnum(m);
185 }
186
RegModuleHelper(py::module * m)187 void RegModuleHelper(py::module *m) {
188 static std::once_flag onlyCalledOnce;
189 std::call_once(onlyCalledOnce, RegModule, m);
190 }
191 } // namespace mindspore
192
193 // Interface with python
PYBIND11_MODULE(_c_expression,m)194 PYBIND11_MODULE(_c_expression, m) {
195 // The OMP_NUM_THREADS has no effect when set in backend, so set it here in advance.
196 mindspore::common::SetOMPThreadNum();
197
198 m.doc() = "MindSpore c plugin";
199
200 mindspore::RegModuleHelper(&m);
201 mindspore::ScopedLongRunning::SetHook(std::make_unique<mindspore::GilScopedLongRunningHook>());
202
203 // Class Pipeline interface
204 MS_LOG(INFO) << "Start GraphExecutorPy...";
205 (void)py::class_<GraphExecutorPy, std::shared_ptr<GraphExecutorPy>>(m, "GraphExecutor_")
206 .def_static("get_instance", &GraphExecutorPy::GetInstance, "Executor get_instance.")
207 .def("__call__", &GraphExecutorPy::Run, py::arg("args"), py::arg("phase") = py::str(""), "Executor run function.")
208 .def("del_net_res", &GraphExecutorPy::DelNetRes, py::arg("obj"), py::arg("network_id") = py::set(),
209 "Delete network resource.")
210 .def("get_func_graph", &GraphExecutorPy::GetFuncGraph, py::arg("phase") = py::str(""), "Get graph pointer.")
211 .def("get_func_graph_proto", &GraphExecutorPy::GetFuncGraphProto, py::arg("phase") = py::str(""),
212 py::arg("type") = py::str("onnx_ir"), py::arg("incremental") = py::bool_(false),
213 "Get graph proto string by specifying ir type.")
214 .def("get_obfuscate_func_graph_proto", &GraphExecutorPy::GetObfuscateFuncGraphProto, py::arg("phase") = py::str(""),
215 py::arg("incremental") = py::bool_(false), py::arg("obf_ratio") = py::float_(1.0),
216 py::arg("branch_control_input") = py::int_(0), "Get graph proto of dynamic-obfuscated model.")
217 .def("get_params", &GraphExecutorPy::GetParams, py::arg("phase") = py::str(""), "Get Parameters from graph")
218 .def("get_random_status", &GraphExecutorPy::GetRandomStatus, py::arg("phase") = py::str(""),
219 "Get random status from graph")
220 .def("compile", &GraphExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("kwargs"),
221 py::arg("phase") = py::str(""), py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
222 .def("updata_param_node_default_input", &GraphExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
223 py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
224 .def("get_parameter_layout", &GraphExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
225 "Get Parameter Tensor Layout Dictionary.")
226 .def("flops_collection", &GraphExecutorPy::FlopsCollection, py::arg("phase") = py::str("train"),
227 "Get model flops information.")
228 .def("get_parallel_graph_info", &GraphExecutorPy::GetParallelGraphInfo, py::arg("phase") = py::str("train"),
229 "Get graph info in step_parallel stage.")
230 .def("get_parallel_parameter_name_list", &GraphExecutorPy::GetParallelParameterNameList,
231 py::arg("phase") = py::str("train"), "Get Parallel Parameter Name List.")
232 .def("get_strategy", &GraphExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),
233 "Get CNode Strategy Dictionary.")
234 .def("get_num_parallel_ops", &GraphExecutorPy::GetNumOpsInfo, py::arg("phase") = py::str("train"),
235 "Get the number of parallel operators.")
236 .def("get_allreduce_fusion", &GraphExecutorPy::GetAllreduceFusion, py::arg("phase") = py::str("train"),
237 "Get Allreduce Fusion Dictionary.")
238 .def("build_data_graph", &GraphExecutorPy::BuildGraph, py::arg("build_params"), py::arg("phase") = py::str("train"),
239 "Build data graph.")
240 .def("export_graph", &GraphExecutorPy::ExportGraph, py::arg("file_name"), py::arg("phase"),
241 py::arg("encrypt") = py::none(), py::arg("key") = nullptr, "Export Graph.")
242 .def("has_compiled", &GraphExecutorPy::HasCompiled, py::arg("phase") = py::str(""), "Get if cell compiled.")
243 .def("set_py_exe_path", &GraphExecutorPy::PyExePath, py::arg("py_exe_path") = py::str(""),
244 "Set python executable path.")
245 .def("set_kernel_build_server_dir", &GraphExecutorPy::KernelBuildServerDir,
246 py::arg("kernel_build_server_dir") = py::str(""), "Set kernel build server directory path.")
247 .def("set_queue_name", &GraphExecutorPy::set_queue_name, py::arg("queue_name") = py::str(""),
248 "Set queue name for the graph loaded from compile cache.")
249 .def("get_queue_name", &GraphExecutorPy::get_queue_name,
250 "Get cached queue name for the graph loaded from compile cache.")
251 .def("set_enable_tuple_broaden", &GraphExecutorPy::set_enable_tuple_broaden,
252 py::arg("enable_tuple_broaden") = py::bool_(false), "Set tuple broaden enable.")
253 .def("set_compile_cache_dep_files", &GraphExecutorPy::set_compile_cache_dep_files,
254 py::arg("compile_cache_dep_files") = py::list(), "Set the compilation cache dependent files.")
255 .def("set_weights_values", &GraphExecutorPy::set_weights_values, py::arg("weights") = py::dict(),
256 "Set values of weights.")
257 .def("get_optimize_graph_proto", &GraphExecutorPy::GetOptimizeGraphProto, py::arg("phase") = py::str(""),
258 "Get the optimize graph proto string.")
259 .def("set_jit_config", &GraphExecutorPy::SetJitConfig, py::arg("jit_config") = py::dict(), "Set the jit config.")
260 .def("generate_arguments_key", &GraphExecutorPy::GenerateArgumentsKey, "Generate unique key of argument.")
261 .def("check_argument_consistency", &GraphExecutorPy::CheckArgumentsConsistency, "Check equal of arguments.")
262 .def("clear_compile_arguments_resource", &GraphExecutorPy::ClearCompileArgumentsResource,
263 "Clear resource when phase cached.")
264 .def("inc_graph_cell_count", &GraphExecutorPy::IncGraphCellCount, "Increase the count of GraphCell instance.")
265 .def("dec_graph_cell_count", &GraphExecutorPy::DecGraphCellCount, "Decrease the count of GraphCell instance.");
266
267 (void)m.def("reset_op_id", &mindspore::pipeline::ResetOpId, "Reset Operator Id");
268 (void)m.def("reset_op_id_with_offset", &mindspore::pipeline::ResetOpIdWithOffset, "Reset Operator Id With Offset");
269 (void)m.def("init_hccl", &mindspore::pipeline::InitHccl, "Init Hccl");
270 (void)m.def("finalize_hccl", &mindspore::pipeline::FinalizeHccl, "Finalize Hccl");
271 (void)m.def("get_hccl_rank_id", &mindspore::pipeline::GetHcclRankId, "Get Hccl Rank Id");
272 (void)m.def("get_hccl_rank_size", &mindspore::pipeline::GetHcclRankSize, "Get Hccl Rank Size");
273 (void)m.def("verify_inputs_signature", &mindspore::pipeline::VerifyInputSignature, "Verify input signature.");
274 (void)m.def("init_exec_dataset", &mindspore::pipeline::InitExecDataset, py::arg("queue_name"), py::arg("size"),
275 py::arg("batch_size"), py::arg("types"), py::arg("shapes"), py::arg("input_indexs"),
276 py::arg("phase") = py::str("dataset"), py::arg("need_run") = py::bool_(true), "Init and exec dataset.");
277 (void)m.def("_set_dataset_mode_config", &mindspore::ConfigManager::SetDatasetModeConfig, "API for set dataset mode.");
278 (void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
279 (void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), py::arg("dec_key") = nullptr,
280 py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"),
281 py::arg("decrypt") = py::none(), py::arg("obfuscated") = py::bool_(false), "Load model as Graph.");
282 (void)m.def("split_mindir", &mindspore::pipeline::SplitMindIR, py::arg("file_name"),
283 "Split single mindir to distributed mindir");
284 (void)m.def("split_dynamic_mindir", &mindspore::pipeline::SplitDynamicMindIR, py::arg("file_name"),
285 py::arg("device_num") = py::int_(8), py::arg("rank_id") = py::int_(0), py::arg("sapp") = py::bool_(true),
286 "Split single mindir to distributed mindir");
287 (void)m.def("dynamic_obfuscate_mindir", &mindspore::pipeline::DynamicObfuscateMindIR, py::arg("file_name"),
288 py::arg("obf_ratio"), py::arg("branch_control_input") = py::int_(0), py::arg("dec_key") = nullptr,
289 py::arg("key_len") = py::int_(0), py::arg("dec_mode") = py::str("AES-GCM"),
290 "Obfuscate a mindir model by dynamic obfuscation.");
291 (void)m.def("init_cluster", &mindspore::distributed::Initialize, "Init Cluster");
292 (void)m.def("set_cluster_exit_with_exception", &mindspore::distributed::set_cluster_exit_with_exception,
293 "Set this process exits with exception.");
294
295 (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
296 .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
297 .def("get_enable_mpi", &mindspore::MpiConfig::enable_mpi, "Get whether enable mpi.")
298 .def("set_enable_mpi", &mindspore::MpiConfig::set_enable_mpi, "Set whether to enable mpi.");
299
300 (void)py::class_<TensorTransform, std::shared_ptr<TensorTransform>>(m, "TensorTransform")
301 .def_static("get_instance", &TensorTransform::GetInstance, "Get tensor_transform instance.")
302 .def("transform_tensor_sharding", &TensorTransform::TransformOperators, "Transform the tensor sharding.");
303 MS_LOG(INFO) << "Start ParallelContext...";
304 (void)py::class_<ParallelContext, std::shared_ptr<ParallelContext>>(m, "AutoParallelContext")
305 .def_static("get_instance", &ParallelContext::GetInstance, "Get auto parallel context instance.")
306 .def("get_device_num", &ParallelContext::device_num, "Get device num.")
307 .def("set_hccl_test_avaible", &ParallelContext::set_hccl_test_available, "Set hccl test available.")
308 .def("set_device_num", &ParallelContext::set_device_num, "Set device num.")
309 .def("get_device_num_is_set", &ParallelContext::device_num_is_set, "Get device num is set.")
310 .def("set_fusion_threshold_mb", &ParallelContext::set_fusion_threshold_mb, "Set fusion threshold.")
311 .def("set_allgather_fusion_threshold_mb", &ParallelContext::set_allgather_fusion_threshold_mb,
312 "Set allgather fusion threshold.")
313 .def("set_reducescatter_fusion_threshold_mb", &ParallelContext::set_reducescatter_fusion_threshold_mb,
314 "Set reducescatter fusion threshold.")
315 .def("fusion_threshold_mb", &ParallelContext::fusion_threshold_mb, "Get allreduce fusion threshold.")
316 .def("allgather_fusion_threshold_mb", &ParallelContext::allgather_fusion_threshold_mb,
317 "Get allgather fusion threshold.")
318 .def("reducescatter_fusion_threshold_mb", &ParallelContext::reducescatter_fusion_threshold_mb,
319 "Get reduce_scatter fusion threshold.")
320 .def("set_fusion_mode", &ParallelContext::set_fusion_mode, "Get fusion mode.")
321 .def("get_fusion_mode", &ParallelContext::get_fusion_mode, "Get fusion mode.")
322 .def("get_global_rank", &ParallelContext::global_rank, "Get global rank.")
323 .def("set_global_rank", &ParallelContext::set_global_rank, "Set global rank.")
324 .def("get_grad_accumulation_shard", &ParallelContext::grad_accumulation_shard, "Get grad_accumulation_shard.")
325 .def("set_grad_accumulation_shard", &ParallelContext::set_grad_accumulation_shard, "Set grad_accumulation_shard.")
326 .def("get_parallel_optimizer_threshold", &ParallelContext::get_parallel_optimizer_threshold, "Get opt threshold.")
327 .def("set_parallel_optimizer_threshold", &ParallelContext::set_parallel_optimizer_threshold, "Set opt threshold.")
328 .def("get_global_rank_is_set", &ParallelContext::global_rank_is_set, "Get global rank is set.")
329 .def("get_gradients_mean", &ParallelContext::gradients_mean, "Get mirror mean.")
330 .def("set_gradients_mean", &ParallelContext::set_gradients_mean, "Set mirror mean.")
331 .def("get_gradient_fp32_sync", &ParallelContext::gradient_fp32_sync, "Get cast before mirror.")
332 .def("set_gradient_fp32_sync", &ParallelContext::set_gradient_fp32_sync, "Set cast before mirror.")
333 .def("get_loss_repeated_mean", &ParallelContext::loss_repeated_mean, "Get loss repeated mean.")
334 .def("set_loss_repeated_mean", &ParallelContext::set_loss_repeated_mean, "Set loss repeated mean.")
335 .def("get_parallel_mode", &ParallelContext::parallel_mode, "Get parallel mode.")
336 .def("set_parallel_mode", &ParallelContext::set_parallel_mode, "Set parallel mode.")
337 .def("get_grad_accumulation_step", &ParallelContext::grad_accumulation_step, "Get grad accumulation step.")
338 .def("set_grad_accumulation_step", &ParallelContext::set_grad_accumulation_step, "Set grad accumulation step.")
339 .def("get_strategy_search_mode", &ParallelContext::strategy_search_mode, "Get strategy search mode.")
340 .def("set_strategy_search_mode", &ParallelContext::set_strategy_search_mode, "Set strategy search mode.")
341 .def("set_all_reduce_fusion_split_indices", &ParallelContext::SetAllReduceFusionSplitIndices,
342 "Set all reduce fusion split indices.")
343 .def("get_all_reduce_fusion_split_indices", &ParallelContext::GetAllReduceFusionSplitIndices,
344 "Get all reduce fusion split indices.")
345 .def("set_all_reduce_fusion_split_sizes", &ParallelContext::SetAllReduceFusionSplitSizes,
346 "Set all reduce fusion split sizes.")
347 .def("get_all_reduce_fusion_split_sizes", &ParallelContext::GetAllReduceFusionSplitSizes,
348 "Get all reduce fusion split sizes.")
349 .def("set_enable_all_reduce_fusion", &ParallelContext::set_enable_all_reduce_fusion,
350 "Set enable/disable all reduce fusion.")
351 .def("get_enable_all_reduce_fusion", &ParallelContext::enable_all_reduce_fusion,
352 "Get enable/disable all reduce fusion.")
353 .def("set_enable_all_gather_fusion", &ParallelContext::set_enable_all_gather_fusion,
354 "Set enable/disable all gather fusion.")
355 .def("get_enable_all_gather_fusion", &ParallelContext::enable_all_gather_fusion,
356 "Get enable/disable all gather fusion.")
357 .def("set_enable_reduce_scatter_fusion", &ParallelContext::set_enable_reduce_scatter_fusion,
358 "Set enable/disable reduce scatter fusion.")
359 .def("get_enable_reduce_scatter_fusion", &ParallelContext::enable_reduce_scatter_fusion,
360 "Get enable/disable reduce scatter fusion.")
361 .def("get_parameter_broadcast", &ParallelContext::parameter_broadcast, "Get parameter broadcast.")
362 .def("get_parameter_broadcast_is_set", &ParallelContext::parameter_broadcast_is_set,
363 "Get parameter broadcast is set.")
364 .def("set_parameter_broadcast", &ParallelContext::set_parameter_broadcast, "Set parameter broadcast.")
365 .def("set_strategy_ckpt_load_file", &ParallelContext::set_strategy_ckpt_load_file,
366 "Set strategy checkpoint load file.")
367 .def("set_strategy_ckpt_save_file", &ParallelContext::set_strategy_ckpt_save_file,
368 "Set strategy checkpoint save file.")
369 .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
370 .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
371 .def("set_group_ckpt_save_file", &ParallelContext::set_group_ckpt_save_file, "Set group checkpoint save file.")
372 .def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num,
373 "Set pipeline stage split num.")
374 .def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")
375 .def("set_auto_pipeline", &ParallelContext::set_auto_pipeline, "Set the pipeline stage number to automatic.")
376 .def("get_auto_pipeline", &ParallelContext::auto_pipeline, "Get whether the pipeline stage number is automatic.")
377 .def("set_pipeline_result_broadcast", &ParallelContext::set_pipeline_result_broadcast,
378 "Set pipeline result broadcast")
379 .def("get_pipeline_result_broadcast", &ParallelContext::pipeline_result_broadcast, "Get pipeline result broadcast")
380 .def("set_pipeline_segment_split_num", &ParallelContext::set_pipeline_segment_split_num,
381 "Set pipeline segment split num.")
382 .def("get_pipeline_segment_split_num", &ParallelContext::pipeline_segment_split_num,
383 "Get pipeline segment split num.")
384 .def("set_pipeline_interleave", &ParallelContext::set_pipeline_interleave, "Set pipeline interleave.")
385 .def("get_pipeline_interleave", &ParallelContext::pipeline_interleave, "Get pipeline interleave.")
386 .def("set_pipeline_scheduler", &ParallelContext::set_pipeline_scheduler, "Set pipeline scheduler.")
387 .def("get_pipeline_scheduler", &ParallelContext::pipeline_scheduler, "Get pipeline scheduler.")
388 .def("set_full_batch", &ParallelContext::set_full_batch, "Set whether load full batch on each device.")
389 .def("get_full_batch", &ParallelContext::full_batch, "Get whether load full batch on each device.")
390 .def("get_full_batch_is_set", &ParallelContext::full_batch_is_set, "Get whether attr full_batch is set.")
391 .def("set_dataset_strategy", &ParallelContext::set_dataset_strategy, "Set dataset sharding strategy.")
392 .def("get_dataset_strategy", &ParallelContext::dataset_strategy, "Get dataset sharding strategy.")
393 .def("set_stra_file_only_trainable_params", &ParallelContext::set_stra_file_only_trainable_params,
394 "Set strategy ckpt only save trainable params.")
395 .def("get_stra_file_only_trainable_params", &ParallelContext::stra_file_only_trainable_params,
396 "Get strategy ckpt only save trainable params.")
397 .def("set_enable_parallel_optimizer", &ParallelContext::set_enable_parallel_optimizer,
398 "Set enable/disable parallel optimizer.")
399 .def("get_enable_parallel_optimizer", &ParallelContext::enable_parallel_optimizer,
400 "Get enable/disable parallel optimizer.")
401 .def("set_force_fp32_communication", &ParallelContext::set_force_fp32_communication,
402 "Set whether to force fp32 communication value.")
403 .def("get_force_fp32_communication", &ParallelContext::force_fp32_communication,
404 "Get the switch whether to force fp32 communication value")
405 .def("get_enable_fold_pipeline", &ParallelContext::enable_fold_pipeline, "Get enable/disable fold pipeline.")
406 .def("set_communi_parallel_mode", &ParallelContext::set_communi_parallel_mode, "Set communication parallel mode.")
407 .def("get_communi_parallel_mode", &ParallelContext::communi_parallel_mode, "Get communication parallel mode.")
408 .def("set_optimizer_weight_shard_size", &ParallelContext::set_optimizer_weight_shard_size,
409 "Set opt shard group size when not fully use parallel optimizer.")
410 .def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size,
411 "Get opt shard group size when not fully use parallel optimizer.")
412 .def("set_optimizer_weight_shard_aggregated_save", &ParallelContext::set_optimizer_weight_shard_aggregated_save,
413 "Set whether to integrated save weight shard when enable parallel optimizer.")
414 .def("get_optimizer_weight_shard_aggregated_save", &ParallelContext::optimizer_weight_shard_aggregated_save,
415 "Get whether to integrated save weight shard when enable parallel optimizer.")
416 .def("set_enable_alltoall", &ParallelContext::set_enable_all2all, "Set the enabling AllToAll value.")
417 .def("get_enable_alltoall", &ParallelContext::enable_all2all, "Get the enabling AllToAll value.")
418 .def("set_sharding_propagation", &ParallelContext::set_sharding_propagation,
419 "Set sharding strategy propagation value.")
420 .def("get_sharding_propagation", &ParallelContext::sharding_propagation, "Get sharding strategy propagation value.")
421 .def("set_ops_strategy_json_config", &ParallelContext::set_ops_strategy_json_config,
422 "Set ops strategy save&load config.")
423 .def("reset", &ParallelContext::Reset, "Reset auto parallel context.");
424 MS_LOG(INFO) << "Start CostModelContext...";
425 (void)py::class_<CostModelContext, std::shared_ptr<CostModelContext>>(m, "CostModelContext")
426 .def_static("get_instance", &CostModelContext::GetInstance, "Get cost_model context instance.")
427 .def("set_device_memory_capacity", &CostModelContext::set_device_memory_capacity,
428 "Set the capacity of device memory.")
429 .def("get_device_memory_capacity", &CostModelContext::device_memory_capacity, "Get the capacity of device memory.")
430 .def("set_costmodel_alpha", &CostModelContext::set_costmodel_alpha,
431 "Set the parameter cost_model_alpha of the DP algorithm.")
432 .def("get_costmodel_alpha", &CostModelContext::costmodel_alpha,
433 "Get the parameter cost_model_alpha of the DP algorithm.")
434 .def("set_costmodel_beta", &CostModelContext::set_costmodel_beta,
435 "Set the parameter cost_model_beta of the DP algorithm.")
436 .def("get_costmodel_beta", &CostModelContext::costmodel_beta,
437 "Get the parameter cost_model_beta of the DP algorithm.")
438 .def("set_costmodel_gamma", &CostModelContext::set_costmodel_gamma,
439 "Set the parameter cost_model_gamma of the DP algorithm")
440 .def("get_costmodel_gamma", &CostModelContext::costmodel_gamma,
441 "Get the parameter cost_model_gamma of the DP algorithm.")
442 .def("set_costmodel_communi_threshold", &CostModelContext::set_costmodel_communi_threshold,
443 "Set the parameter cost_model_communi_threshold of the DP algorithm.")
444 .def("get_costmodel_communi_threshold", &CostModelContext::costmodel_communi_threshold,
445 "Get the parameter cost_model_communi_threshold of the DP algorithm.")
446 .def("set_costmodel_communi_const", &CostModelContext::set_costmodel_communi_const,
447 "Set the parameter cost_model_communi_const of the DP algorithm.")
448 .def("get_costmodel_communi_const", &CostModelContext::costmodel_communi_const,
449 "Get the parameter cost_model_communi_const of the DP algorithm.")
450 .def("set_costmodel_communi_bias", &CostModelContext::set_costmodel_communi_bias,
451 "Set the parameter cost_model_communi_bias of the DP algorithm.")
452 .def("get_costmodel_communi_bias", &CostModelContext::costmodel_communi_bias,
453 "Get the parameter cost_model_communi_bias of the DP algorithm.")
454 .def("set_multi_subgraphs", &CostModelContext::set_multi_subgraphs, "Set the parameter is_multi_subgraphs.")
455 .def("get_multi_subgraphs", &CostModelContext::is_multi_subgraphs, "Get the parameter is_multi_subgraphs.")
456 .def("set_run_phase", &CostModelContext::set_run_phase, "Set the flag run_phase.")
457 .def("get_run_phase", &CostModelContext::run_phase, "Get the flag run_phase.")
458 .def("set_costmodel_allreduce_fusion_algorithm", &CostModelContext::set_costmodel_allreduce_fusion_algorithm,
459 "Set the parameter gradient AllReduce fusion algorithm.")
460 .def("get_costmodel_allreduce_fusion_algorithm", &CostModelContext::costmodel_allreduce_fusion_algorithm,
461 "Get the parameter gradient AllReduce fusion algorithm.")
462 .def("set_costmodel_allreduce_fusion_times", &CostModelContext::set_costmodel_allreduce_fusion_times,
463 "Set the parameter gradient AllReduce times.")
464 .def("get_costmodel_allreduce_fusion_times", &CostModelContext::costmodel_allreduce_fusion_times,
465 "Get the parameter gradient AllReduce times.")
466 .def("set_costmodel_allreduce_fusion_tail_percent", &CostModelContext::set_costmodel_allreduce_fusion_tail_percent,
467 "Set the parameter gradient AllReduce fusion tail percent.")
468 .def("get_costmodel_allreduce_fusion_tail_percent", &CostModelContext::costmodel_allreduce_fusion_tail_percent,
469 "Get the parameter gradient AllReduce fusion tail percent.")
470 .def("set_costmodel_allreduce_fusion_tail_time", &CostModelContext::set_costmodel_allreduce_fusion_tail_time,
471 "Set the parameter gradient AllReduce fusion tail time.")
472 .def("get_costmodel_allreduce_fusion_tail_time", &CostModelContext::costmodel_allreduce_fusion_tail_time,
473 "Get the parameter gradient AllReduce fusion tail time.")
474 .def("set_costmodel_allreduce_fusion_allreduce_inherent_time",
475 &CostModelContext::set_costmodel_allreduce_fusion_allreduce_inherent_time,
476 "Set the parameter gradient AllReduce fusion allreduce inherent time.")
477 .def("get_costmodel_allreduce_fusion_allreduce_inherent_time",
478 &CostModelContext::costmodel_allreduce_fusion_allreduce_inherent_time,
479 "Get the parameter gradient AllReduce fusion allreduce inherent time.")
480 .def("set_costmodel_allreduce_fusion_allreduce_bandwidth",
481 &CostModelContext::set_costmodel_allreduce_fusion_allreduce_bandwidth,
482 "Set the parameter gradient AllReduce fusion allreduce bandwidth.")
483 .def("get_costmodel_allreduce_fusion_allreduce_bandwidth",
484 &CostModelContext::costmodel_allreduce_fusion_allreduce_bandwidth,
485 "Get the parameter gradient AllReduce fusion allreduce bandwidth.")
486 .def("set_costmodel_allreduce_fusion_computation_time_parameter",
487 &CostModelContext::set_costmodel_allreduce_fusion_computation_time_parameter,
488 "Set the parameter gradient AllReduce fusion computation time parameter.")
489 .def("get_costmodel_allreduce_fusion_computation_time_parameter",
490 &CostModelContext::costmodel_allreduce_fusion_computation_time_parameter,
491 "Get the parameter gradient AllReduce fusion computation time parameter.")
492 .def("set_tensor_slice_align_enable", &CostModelContext::set_tensor_slice_alignment_enable,
493 "Set the parameter tensor_slice_align_enable in strategy generation.")
494 .def("get_tensor_slice_align_enable", &CostModelContext::tensor_slice_alignment_enable,
495 "Get the parameter tensor_slice_align_enable in strategy generation.")
496 .def("set_tensor_slice_align_size", &CostModelContext::set_tensor_slice_alignment_size,
497 "Set the parameter tensor_slice_size in strategy generation.")
498 .def("get_tensor_slice_align_size", &CostModelContext::tensor_slice_alignment_size,
499 "Get the parameter tensor_slice_size in strategy generation.")
500 .def("set_fully_use_devices", &CostModelContext::set_fully_use_device,
501 "Set the parameter fully_use_devices in the DP algorithm.")
502 .def("get_fully_use_devices", &CostModelContext::fully_use_device,
503 "Get the parameter fully_use_devices in the DP algorithm.")
504 .def("set_elementwise_op_strategy_follow", &CostModelContext::set_elementwise_stra_follow,
505 "Set the parameter elementwise_op_strategy_follow in the DP algorithm.")
506 .def("get_elementwise_op_strategy_follow", &CostModelContext::elementwise_stra_follow,
507 "Get the parameter elementwise_op_strategy_follow in the DP algorithm.")
508 .def("set_dp_algo_enable_approxi", &CostModelContext::set_dp_algo_enable_approxi,
509 "Set the flag whether enabling approximation in the DP algorithm.")
510 .def("get_dp_algo_enable_approxi", &CostModelContext::dp_algo_enable_approxi,
511 "Get the flag whether enabling approximation in the DP algorithm.")
512 .def("set_dp_algo_approxi_epsilon", &CostModelContext::set_dp_algo_approxi_epsilon,
513 "Set the epsilon which is used in the approximation of DP algorithm.")
514 .def("get_dp_algo_approxi_epsilon", &CostModelContext::dp_algo_approxi_epsilon,
515 "Get the epsilon which is used in the approximation of DP algorithm.")
516 .def("set_rp_matmul_mem_coef", &CostModelContext::set_rp_matmul_mem_coef,
517 "Set the matmul memory coef which is used in the RP algorithm.")
518 .def("get_rp_matmul_mem_coef", &CostModelContext::rp_matmul_mem_coef,
519 "Get the matmul memory coef which is used in the RP algorithm.")
520 .def("set_dp_algo_single_loop", &CostModelContext::set_dp_algo_single_loop,
521 "Set the flag of generating a single suite of OperatorInfos in for-loop.")
522 .def("get_dp_algo_single_loop", &CostModelContext::dp_algo_single_loop,
523 "Get the flag of whether or not generating a single suite of OperatorInfos in for-loop.")
524 .def("reset_cost_model", &CostModelContext::ResetCostModel, "Reset the CostModelContext.")
525 .def("reset_algo_parameters", &CostModelContext::ResetAlgoParameters, "Reset the AlgoParameters.");
526 MS_LOG(INFO) << "Start OffloadContext...";
527 (void)py::class_<OffloadContext, std::shared_ptr<OffloadContext>>(m, "OffloadContext")
528 .def_static("get_instance", &OffloadContext::GetInstance, "Get offload context instance.")
529 .def("set_offload_param", &OffloadContext::set_offload_param, "Set the param for offload destination, cpu or disk.")
530 .def("offload_param", &OffloadContext::offload_param, "Get the param for offload destination.")
531 .def("set_offload_path", &OffloadContext::set_offload_path, "Set the path of offload.")
532 .def("offload_path", &OffloadContext::offload_path, "Get the path of offload.")
533 .def("set_offload_checkpoint", &OffloadContext::set_offload_checkpoint,
534 "Set the checkpoint for offload destination, cpu or disk.")
535 .def("offload_checkpoint", &OffloadContext::offload_checkpoint, "Get the checkpoint for offload destination.")
536 .def("set_offload_cpu_size", &OffloadContext::set_offload_cpu_size, "Set the cpu memory size for offload.")
537 .def("offload_cpu_size", &OffloadContext::offload_cpu_size, "Get the cpu memory size for offload.")
538 .def("set_offload_disk_size", &OffloadContext::set_offload_disk_size, "Set the disk size for offload.")
539 .def("offload_disk_size", &OffloadContext::offload_disk_size, "Get the disk size for offload.")
540 .def("set_enable_aio", &OffloadContext::set_enable_aio, "Set the flag of whether enabling aio.")
541 .def("enable_aio", &OffloadContext::enable_aio, "Get the flag of whether enabling aio.")
542 .def("set_aio_block_size", &OffloadContext::set_aio_block_size, "Set the size of aio block.")
543 .def("aio_block_size", &OffloadContext::aio_block_size, "Get the size of aio block.")
544 .def("set_aio_queue_depth", &OffloadContext::set_aio_queue_depth, "Set the depth of aio queue.")
545 .def("aio_queue_depth", &OffloadContext::aio_queue_depth, "Get the depth of aio queue.")
546 .def("set_enable_pinned_mem", &OffloadContext::set_enable_pinned_mem,
547 "Set the flag of whether enabling pinned memory.")
548 .def("enable_pinned_mem", &OffloadContext::enable_pinned_mem, "Get the flag of whether enabling pinned memory.")
549 .def("set_auto_offload", &OffloadContext::set_auto_offload,
550 "Set whether to automatically generate the offload strategy")
551 .def("auto_offload", &OffloadContext::auto_offload, "Get the flag of whether auto offload")
552 .def("set_host_mem_block_size", &OffloadContext::set_host_mem_block_size, "Set the block size for host memory pool")
553 .def("host_mem_block_size", &OffloadContext::host_mem_block_size, "Get the block size of host memory pool")
554 .def("set_cpu_ratio", &OffloadContext::set_cpu_ratio, "Set the cpu memory usage ratio for offload strategy")
555 .def("cpu_ratio", &OffloadContext::cpu_ratio, "Get the cpu memory usage ratio of offload strategy")
556 .def("set_hbm_ratio", &OffloadContext::set_hbm_ratio, "Set the hbm usage ratio for offload strategy")
557 .def("hbm_ratio", &OffloadContext::hbm_ratio, "Get the hbm usage ratio of offload strategy");
558
559 (void)py::module::import("atexit").attr("register")(py::cpp_function{[&]() -> void {
560 MS_LOG(INFO) << "Start register...";
561 mindspore::MsContext::GetInstance()->RegisterCheckEnv(nullptr);
562 mindspore::MsContext::GetInstance()->RegisterSetEnv(nullptr);
563 #ifndef ENABLE_SECURITY
564 MS_LOG(INFO) << "Start mindspore.profiler...";
565 try {
566 py::module profiler = py::module::import("mindspore.profiler").attr("EnvProfiler")();
567 (void)profiler.attr("analyse")();
568 } catch (const std::exception &e) {
569 MS_LOG(ERROR) << "Failed to parse profiler data." << e.what();
570 }
571 #endif
572 MS_LOG(INFO) << "Start EmbeddingCacheScheduler...";
573 #if defined(__linux__) && defined(WITH_BACKEND)
574 mindspore::runtime::EmbeddingCacheScheduler::GetInstance().Finalize(
575 !mindspore::distributed::cluster_exit_with_exception());
576 #endif
577
578 #ifdef ENABLE_MINDDATA
579 MS_LOG(INFO) << "Start releasing dataset handles...";
580 py::module iterators = py::module::import("mindspore.dataset.engine.iterators");
581 (void)iterators.attr("_cleanup")();
582 MS_LOG(INFO) << "End release dataset handles.";
583 #endif
584 mindspore::pipeline::FinalizeCluster();
585
586 // only in case that c++ calling python interface, ClearResAtexit should be called.
587 if (mindspore::python_adapter::IsPythonEnv()) {
588 mindspore::pipeline::ClearResAtexit();
589 }
590 }});
591
592 #ifndef ENABLE_SECURITY
593 (void)py::class_<EventWriter, std::shared_ptr<EventWriter>>(m, "EventWriter_")
594 .def(py::init<const std::string &>())
595 .def("GetFileName", &EventWriter::GetFileName, "Get the file name.")
596 .def("Open", &EventWriter::Open, "Open the write file.")
597 .def("Write", &EventWriter::Write, "Write the serialize event.")
598 .def("EventCount", &EventWriter::GetWriteEventCount, "Write event count.")
599 .def("Flush", &EventWriter::Flush, "Flush the event.")
600 .def("Close", &EventWriter::Close, "Close the write.")
601 .def("Shut", &EventWriter::Shut, "Final close the write.");
602 #endif // ENABLE_SECURITY
603
604 (void)py::class_<OpLib, std::shared_ptr<OpLib>>(m, "Oplib")
605 .def(py::init())
606 .def_static("reg_op", &OpLib::RegOp, "Register op info.");
607
608 (void)py::class_<CollectiveManager, std::shared_ptr<CollectiveManager>>(m, "CollectiveManager")
609 .def_static("get_instance", &CollectiveManager::instance, "Get collective manager instance.")
610 .def("initialized", &CollectiveManager::initialized, "Returns whether distributed module is initialized.")
611 .def("create_group", &CollectiveManager::CreateCommunicationGroup, "Create collective group.")
612 .def("destroy_group", &CollectiveManager::DestroyCommunicationGroup, "Destroy collective group.")
613 .def("get_local_rank_id", &CollectiveManager::GetLocalRankId, "Get the node rank id.")
614 .def("get_local_group_size", &CollectiveManager::GetLocalGroupSize, "Get the node rank id.")
615 .def("get_world_rank_from_group_rank", &CollectiveManager::GetWorldRankFromGroupRank,
616 "Get world rank by group rank.")
617 .def("get_group_rank_from_world_rank", &CollectiveManager::GetGroupRankFromWorldRank,
618 "Get group rank by world rank.")
619 .def("get_rank_id", &CollectiveManager::GetRankId, "Get the node rank id.")
620 .def("get_group_size", &CollectiveManager::GetGroupSize, "Get the nodes number in the collective communication.")
621 .def("get_group_ranks", &CollectiveManager::GetGroupRanks,
622 "Get group ranks for the specified communication group.");
623
624 (void)py::class_<PSContext, std::shared_ptr<PSContext>>(m, "PSContext")
625 .def_static("get_instance", &PSContext::instance, "Get PS context instance.")
626 .def("set_ps_enable", &PSContext::SetPSEnable, "Set PS mode enabled or disabled.")
627 .def("is_ps_mode", &PSContext::is_ps_mode, "Get PS mode enable-disable status.")
628 .def("reset", &PSContext::Reset, "Reset PS context attributes.")
629 .def("is_worker", &PSContext::is_worker, "Get whether the role of this process is Worker.")
630 .def("is_server", &PSContext::is_server, "Get whether the role of this process is PServer.")
631 .def("is_scheduler", &PSContext::is_scheduler, "Get whether the role of this process is Scheduler.")
632 .def("ps_rank_id", &PSContext::ps_rank_id, "Get Worker and PServer rank id.")
633 .def("insert_hash_table_size", &PSContext::InsertHashTableSize, "Insert hash table size.")
634 .def("reinsert_hash_table_size", &PSContext::ReInsertHashTableSize,
635 "Insert hash table size with new parameter name.")
636 .def("insert_accumu_init_info", &PSContext::InsertAccumuInitInfo, "Insert accumulation initialization value.")
637 .def("clone_hash_table", &PSContext::CloneHashTable, "Clone a hash table.")
638 .def("set_cache_enable", &PSContext::set_cache_enable, "Set ps mode cache enable or not.")
639 .def("set_cache_size", &PSContext::set_cache_size, "Set embedding cache size for ps cache mode.")
640 .def("cache_enable", &PSContext::cache_enable, "Get ps mode cache enable or not.")
641 .def("set_sparse_format", &PSContext::set_sparse_format, "Set the storage format of the embedding table.")
642 .def("set_rank_id", &PSContext::set_rank_id, "Set rank id for worker on ps mode.")
643 .def("set_server_mode", &PSContext::set_server_mode, "Set server mode.")
644 .def("server_mode", &PSContext::server_mode, "Get server mode.")
645 .def("set_ms_role", &PSContext::set_ms_role, "Set role for this process.")
646 .def("ms_role", &PSContext::ms_role, "Get role for this process.")
647 .def("set_worker_num", &PSContext::set_worker_num, "Set worker number.")
648 .def("worker_num", &PSContext::worker_num, "Get worker number.")
649 .def("set_server_num", &PSContext::set_server_num, "Set server number.")
650 .def("server_num", &PSContext::server_num, "Get server number.")
651 .def("set_scheduler_ip", &PSContext::set_scheduler_ip, "Set scheduler ip.")
652 .def("scheduler_ip", &PSContext::scheduler_ip, "Get scheduler ip.")
653 .def("set_scheduler_port", &PSContext::set_scheduler_port, "Set scheduler port.")
654 .def("scheduler_port", &PSContext::scheduler_port, "Get scheduler port.")
655 .def("set_scheduler_manage_port", &PSContext::set_scheduler_manage_port,
656 "Set scheduler manage port used to scale out/in.")
657 .def("scheduler_manage_port", &PSContext::scheduler_manage_port, "Get scheduler manage port used to scale out/in.")
658 .def("set_enable_ssl", &PSContext::set_enable_ssl, "Set PS SSL mode enabled or disabled.")
659 .def("enable_ssl", &PSContext::enable_ssl, "Get PS SSL mode enabled or disabled.")
660 .def("set_client_password", &PSContext::set_client_password, "Set the client password to decode the p12 file.")
661 .def("client_password", &PSContext::client_password, "Get the client password to decode the p12 file.")
662 .def("set_server_password", &PSContext::set_server_password, "Set the server password to decode the p12 file.")
663 .def("server_password", &PSContext::server_password, "Get the server password to decode the p12 file.")
664 .def("set_config_file_path", &PSContext::set_config_file_path,
665 "Set configuration files required by the communication layer.")
666 .def("config_file_path", &PSContext::config_file_path,
667 "Get configuration files required by the communication layer.")
668 .def("enable_distributed_mindrt", &PSContext::enable_distributed_mindrt, "Whether distributed MindRT is enabled.")
669 .def("set_checkpoint_load_status", &PSContext::set_checkpoint_load_status, "Set checkpoint load status.")
670 .def("store_warm_up_ptr_by_tensor", &PSContext::StoreWarmUpPtrByTensor, "Store warm up host cache by tensor.")
671 .def("store_warm_up_ptr_by_tensor_list", &PSContext::StoreWarmUpPtrByTensorList,
672 "Store warm up host cache by tensor list");
673 (void)m.def("_encrypt", &mindspore::pipeline::PyEncrypt, "Encrypt the data.");
674 (void)m.def("_decrypt", &mindspore::pipeline::PyDecrypt, "Decrypt the data.");
675 (void)m.def("_decrypt_data", &mindspore::pipeline::PyDecryptData, "Decrypt the bytes data.");
676 (void)m.def("_is_cipher_file", &mindspore::pipeline::PyIsCipherFile, "Determine whether the file is encrypted");
677
678 (void)py::class_<RecoveryContext, std::shared_ptr<RecoveryContext>>(m, "RecoveryContext")
679 .def_static("get_instance", &RecoveryContext::GetInstance, "Get recovery context instance.")
680 .def("enable_recovery", &RecoveryContext::enable_recovery, "Get whether enable recovery.")
681 .def("latest_ckpt_file", &RecoveryContext::latest_ckpt_file, "Get latest checkpoint file path.")
682 .def("latest_ckpt_epoch", &RecoveryContext::latest_ckpt_epoch, "Get the epoch of latest checkpoint.")
683 .def("latest_ckpt_step", &RecoveryContext::latest_ckpt_step, "Get the step of latest checkpoint.")
684 .def("set_need_reset", &RecoveryContext::set_need_reset,
685 "Set whether should call reset minddata and load ckpt for disaster recovery.")
686 .def("need_reset", &RecoveryContext::need_reset,
687 "Get whether should call reset minddata and load ckpt for disaster recovery.")
688 .def("recovery_path", &RecoveryContext::recovery_path,
689 "Get the recovery path used to save that need to be persisted.")
690 .def("ckpt_path", &RecoveryContext::GetCkptPath, "Get the recovery path used to save checkpoint.")
691 .def("set_ckpt_path", &RecoveryContext::SetCkptPath, "Set the recovery path used to save checkpoint.");
692
693 (void)py::class_<DeviceContextManager, std::shared_ptr<DeviceContextManager>>(m, "DeviceContextManager")
694 .def_static("get_instance", &DeviceContextManager::GetInstance, py::return_value_policy::reference,
695 "Get device context manager instance.")
696 .def("get_device_context", &DeviceContextManager::GetDeviceContext, "Return device context object.");
697 (void)py::class_<DeviceContext, std::shared_ptr<DeviceContext>>(m, "DeviceContext")
698 .def("initialized", &DeviceContext::initialized, "Return whether this device backend is successfully initialized.");
699 DeviceContextManager::GetInstance().RegisterDeviceStatelessFunc(&m);
700
701 (void)m.def("_ms_memory_recycle", &mindspore::pipeline::MemoryRecycle, "Recycle memory used by mindspore.");
702 (void)m.def("_bind_device_ctx", &mindspore::pipeline::BindDeviceCtx, "Bind device context to current thread");
703 (void)m.def("swap_cache", &mindspore::pipeline::SwapCache, py::arg("host"), py::arg("device"),
704 py::arg("block_mapping"), py::arg("is_device_to_host"), "Swap Cache for PageAttention.");
705 }
706