1 /**
2 * Copyright 2019-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "pipeline/pynative/pynative_execute.h"
18 #include "pipeline/pynative/pynative_utils.h"
19 #include "pipeline/pynative/grad/ir/ir_bprop.h"
20 #include "pipeline/pynative/predict_out_type_map.h"
21 #include "pipeline/jit/ps/debug/trace.h"
22 #include "pybind_api/pybind_patch.h"
23 #include "pybind_api/gil_scoped_long_running.h"
24 #include "pybind_api/ir/hook_py.h"
25 #include "include/common/utils/config_manager.h"
26 #include "include/common/pybind_api/api_register.h"
27 #include "frontend/optimizer/ad/grad.h"
28 #include "pipeline/jit/ps/pass.h"
29 #include "runtime/pynative/op_executor.h"
30 #include "runtime/pynative/op_compiler.h"
31 #include "runtime/pynative/op_runner.h"
32 #include "include/common/profiler.h"
33 #include "ir/cell.h"
34 #include "include/common/utils/stub_tensor.h"
35 #include "include/common/utils/python_utils.h"
36 #include "frontend/operator/ops_front_infer_function.h"
37 #include "kernel/kernel_mod_cache.h"
38 #include "runtime/pipeline/pipeline.h"
39
40 namespace mindspore::pynative {
41 std::shared_ptr<PyNativeExecutor> PyNativeExecutor::executor_ = nullptr;
42 ForwardExecutorPtr PyNativeExecutor::forward_executor_ = nullptr;
43 GradExecutorPtr PyNativeExecutor::grad_executor_ = nullptr;
44 std::mutex PyNativeExecutor::instance_lock_;
45
46 namespace {
47 template <typename T, typename... Args>
PyNativeExecutorTry(const std::function<T (const Args &...)> & method,const Args &...args)48 T PyNativeExecutorTry(const std::function<T(const Args &...)> &method, const Args &... args) {
49 const auto &inst = PyNativeExecutor::GetInstance();
50 MS_EXCEPTION_IF_NULL(inst);
51 MS_EXCEPTION_IF_NULL(method);
52 auto already_set_error_handler = [&inst]() {
53 // Print function call stack info before release.
54 std::ostringstream oss;
55 trace::TraceGraphEval();
56 trace::GetEvalStackInfo(oss);
57 // Call py::print to output function call stack to STDOUT, in case of output the log to file, the user can see
58 // these info from screen, no need to open log file to find these info.
59 py::print(oss.str());
60 MS_LOG(ERROR) << oss.str();
61 inst->ClearRes();
62 };
63
64 if constexpr (std::is_same_v<T, void>) {
65 HandleExceptionRethrow([&method, &args...]() { method(args...); }, already_set_error_handler,
66 [&inst]() { inst->ClearRes(); }, [&inst]() { inst->ClearRes(); });
67 } else {
68 T res;
69 HandleExceptionRethrow([&res, &method, &args...]() { res = method(args...); }, already_set_error_handler,
70 [&inst]() { inst->ClearRes(); }, [&inst]() { inst->ClearRes(); });
71 return res;
72 }
73 }
74
75 // Tensor may be used before the execution of the asynchronous task.
SetCallbackForInputTensor(const std::vector<ValuePtr> & input_values)76 void SetCallbackForInputTensor(const std::vector<ValuePtr> &input_values) {
77 for (auto &input : input_values) {
78 MS_EXCEPTION_IF_NULL(input);
79 if (input->isa<tensor::BaseTensor>()) {
80 auto tensor = input->cast<tensor::BaseTensorPtr>();
81 MS_EXCEPTION_IF_NULL(tensor);
82 tensor->set_need_pipeline_sync(true);
83 }
84 }
85 }
86 } // namespace
87
StoreAsyncStatus(const FrontendOpRunInfoPtr & op_run_info) const88 void PyNativeExecutor::StoreAsyncStatus(const FrontendOpRunInfoPtr &op_run_info) const {
89 // Pure function running or cell not set mix precision
90 op_run_info->async_status.disable_mix_precision =
91 (forward_executor()->IsFirstCell() || forward_executor()->CellNotSetMixedPrecision(op_run_info));
92 op_run_info->async_status.is_jit_compiling = forward_executor()->is_jit_compiling();
93 op_run_info->async_status.custom_bprop_cell_count = grad_executor()->custom_bprop_cell_count();
94 }
95
// Runs one op requested from Python.
// When the pipeline is not enabled for the op, pending forward tasks are awaited, the op is run through
// RunOpS and its real output is converted to a Python object. Otherwise a stub node is created, the op is
// dispatched as a frontend task, and the Python side of the stub node is returned.
// Raises (via MS_EXCEPTION_IF_NULL) when the generated run info or its grad info is null.
py::object PyNativeExecutor::RunOpStub(const py::args &args) const {
  FrontendOpRunInfoPtr op_run_info = forward_executor()->GenerateOpRunInfo(args, true);
  // Both pointers are dereferenced below; fail with a clear error rather than crash.
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(op_run_info->op_grad_info);
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kRunOp,
                                     op_run_info->base_op_run_info.op_name, false, true);
  SetCallbackForInputTensor(op_run_info->op_grad_info->input_value);

  StoreAsyncStatus(op_run_info);
  const auto &op_name = op_run_info->base_op_run_info.op_name;
  // 1. get top_type from Primitive::PredictOutputType
  auto top_type = PredictOutType(op_run_info);
  // 2. if disable PyTraceAsync, return after infer(half-asynchronous) or run(synchronous mode)
  if (!forward_executor()->EnablePipeline(op_name)) {
    // Wait for async task finish
    forward_executor()->WaitForwardTask();
    // RunOp sync
    PyNativeExecutorTry(forward_executor()->RunOpS, op_run_info);
    return PyNativeAlgo::DataConvert::ValueToPyObj(op_run_info->real_out);
  }
  // 3. create top stub node
  auto node = stub::MakeTopNode(top_type);
  {
    // The task in the AsyncQueue may need to acquire gil.
    // NOTE(review): GilReleaseWithCheck is assumed to drop the GIL for its lifetime. Its scope is closed
    // before the return so that copying the py::object below does not touch a Python reference count
    // without the GIL.
    GilReleaseWithCheck release_gil;
    // 4. set abstract and value in asynchronous thread after infer and run
    op_run_info->stub_output = node.second;
    forward_executor()->DispatchFrontendTask(op_run_info);
  }
  // 5. return stub node
  return node.first;
}
124
RunSliceOpStub(const std::vector<ValuePtr> & input_values,const std::vector<SliceOpInfoPtr> & slice_op_infos) const125 py::object PyNativeExecutor::RunSliceOpStub(const std::vector<ValuePtr> &input_values,
126 const std::vector<SliceOpInfoPtr> &slice_op_infos) const {
127 runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kRunOp);
128 SetCallbackForInputTensor(input_values);
129 auto requires_grad = grad_executor()->RequiresGrad();
130 if (!forward_executor()->EnablePipeline("")) {
131 forward_executor()->WaitForwardTask();
132 auto ret = forward_executor()->RunSliceOpFrontend(input_values, slice_op_infos, requires_grad, nullptr);
133 return PyNativeAlgo::DataConvert::ValueToPyObj(ret);
134 }
135 auto top_type = kTensorType;
136 auto node = stub::MakeTopNode(top_type);
137 GilReleaseWithCheck release_gil;
138 forward_executor()->DispatchSilceOpFrontendTask(input_values, slice_op_infos, requires_grad, node.second);
139 return node.first;
140 }
141
// Generates the run info for `args`, runs the op through RunOpS and converts the real output to a Python
// object. The conversion is done under a freshly acquired GIL when PyGILState_Check() reports that the
// calling thread does not hold it.
py::object PyNativeExecutor::RealRunOp(const py::args &args) const {
  FrontendOpRunInfoPtr op_run_info = forward_executor()->GenerateOpRunInfo(args);
  StoreAsyncStatus(op_run_info);
  PyNativeExecutorTry(forward_executor()->RunOpS, op_run_info);

  const auto to_py_output = [&op_run_info]() {
    return PyNativeAlgo::DataConvert::ValueToPyObj(op_run_info->real_out);
  };
  if (PyGILState_Check() != 0) {
    // GIL already held by this thread.
    return to_py_output();
  }
  py::gil_scoped_acquire acquire;
  return to_py_output();
}
153
// Delegates constant folding of `args` to the forward executor's infer operation.
py::object PyNativeExecutor::CallConstantFolding(const py::args &args) const {
  const auto &infer_operation = forward_executor()->infer_operation();
  return infer_operation->CallConstantFolding(args);
}
157
set_py_exe_path(const py::object & py_exe_path) const158 void PyNativeExecutor::set_py_exe_path(const py::object &py_exe_path) const {
159 if (!py::isinstance<py::str>(py_exe_path)) {
160 MS_LOG(EXCEPTION) << "Failed, py_exe_path input is not a str";
161 }
162 const auto &py_exe_path_s = py_exe_path.cast<std::string>();
163 auto ms_context = MsContext::GetInstance();
164 ms_context->set_param<std::string>(MS_CTX_PYTHON_EXE_PATH, py_exe_path_s);
165 }
166
set_kernel_build_server_dir(const py::object & kernel_build_server_dir) const167 void PyNativeExecutor::set_kernel_build_server_dir(const py::object &kernel_build_server_dir) const {
168 if (!py::isinstance<py::str>(kernel_build_server_dir)) {
169 MS_LOG(EXCEPTION) << "Failed, kernel_build_server_dir input is not a str";
170 }
171 const auto &kernel_build_server_dir_s = kernel_build_server_dir.cast<std::string>();
172 auto ms_context = MsContext::GetInstance();
173 ms_context->set_param<std::string>(MS_CTX_KERNEL_BUILD_SERVER_DIR, kernel_build_server_dir_s);
174 }
175
ClearRes() const176 void PyNativeExecutor::ClearRes() const {
177 forward_executor()->WaitForwardTask();
178 runtime::OpExecutor::GetInstance().Wait();
179 // Clear forward tasks before clear op graphs cache.
180 pynative::OpCompiler::GetInstance().ClearAllCache();
181 kernel::KernelModCache::GetInstance().ClearAllCache();
182 pynative::autograd::ClearAutoGradCache();
183 tensor::RegisterHook::ClearHookMap();
184
185 // Maybe exit in runop step
186 auto ms_context = MsContext::GetInstance();
187 if (ms_context != nullptr) {
188 ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
189 }
190 ConfigManager::GetInstance().ResetIterNum();
191 if (forward_executor_ != nullptr) {
192 forward_executor_->ClearRes();
193 }
194 if (grad_executor_ != nullptr) {
195 grad_executor_->ClearRes();
196 }
197 ad::CleanRes();
198 pipeline::ReclaimOptimizer();
199 MS_LOG(DEBUG) << "Clear all res";
200 }
201
Init()202 void PyNativeExecutor::Init() {
203 MS_LOG(DEBUG) << "Init PyNativeExecutor";
204 forward_executor_ = std::make_shared<ForwardExecutor>();
205 forward_executor_->Init();
206 grad_executor_ = std::make_shared<GradExecutor>(forward_executor_);
207 grad_executor_->Init();
208 forward_executor_->set_grad_executor(grad_executor_);
209 forward_executor_->RefreshForwardCallback();
210 runtime::ProfilerAnalyzer::GetInstance().SetThreadIdToName(std::this_thread::get_id(), "Python");
211 }
212
Sync() const213 void PyNativeExecutor::Sync() const {
214 forward_executor()->Sync();
215 runtime::ProfilerAnalyzer::GetInstance().EndStep();
216 runtime::ProfilerAnalyzer::GetInstance().StartStep();
217 }
218
SetHookChanged(const py::object & cell) const219 void PyNativeExecutor::SetHookChanged(const py::object &cell) const {
220 if (!py::isinstance<Cell>(cell)) {
221 MS_LOG(EXCEPTION) << "The 'set_hook_changed' function is only supported on Cell object!";
222 }
223 grad_executor()->SetHookChanged(cell);
224 }
225
grad_flag() const226 bool PyNativeExecutor::grad_flag() const { return grad_executor()->grad_flag(); }
227
set_grad_flag(bool flag) const228 void PyNativeExecutor::set_grad_flag(bool flag) const { grad_executor()->set_grad_flag(flag); }
229
enable_grad() const230 bool PyNativeExecutor::enable_grad() const { return grad_executor()->enable_grad(); }
231
set_enable_grad(bool enable_grad) const232 void PyNativeExecutor::set_enable_grad(bool enable_grad) const { grad_executor()->set_enable_grad(enable_grad); }
233
CheckAlreadyRun(const prim::GradOperationPtr & grad,const py::object & obj,const py::object & weights,const py::object & grad_hash_id,const py::args & args) const234 py::object PyNativeExecutor::CheckAlreadyRun(const prim::GradOperationPtr &grad, const py::object &obj,
235 const py::object &weights, const py::object &grad_hash_id,
236 const py::args &args) const {
237 return grad_executor()->CheckAlreadyRun(grad, obj, weights, grad_hash_id, args);
238 }
239
NewGraph(const py::object & obj,const py::args & args) const240 void PyNativeExecutor::NewGraph(const py::object &obj, const py::args &args) const {
241 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeNewGraph,
242 runtime::ProfilerRecorder::kNoName, false);
243 forward_executor()->ProcessBeforeNewGraph(obj, args);
244
245 if (!grad_executor()->RequiresGrad()) {
246 MS_LOG(DEBUG) << "Grad flag is false";
247 return;
248 }
249 PyNativeExecutorTry(grad_executor()->InitGraph, obj, args);
250 forward_executor()->ProcessAfterNewGraph(obj);
251 }
252
EndGraph(const py::object & obj,const py::object & out,const py::args & args) const253 void PyNativeExecutor::EndGraph(const py::object &obj, const py::object &out, const py::args &args) const {
254 runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeEndGraph,
255 runtime::ProfilerRecorder::kNoName, false);
256 bool is_cell = py::isinstance<Cell>(obj);
257 forward_executor()->ProcessBeforeEndGraph(obj, is_cell);
258
259 if (!grad_executor()->RequiresGrad()) {
260 MS_LOG(DEBUG) << "Grad flag is false";
261 return;
262 }
263 PyNativeExecutorTry(grad_executor()->LinkGraph, obj, out, args);
264 forward_executor()->ProcessAfterEndGraph(obj, is_cell);
265 }
266
RunGrad(const prim::GradOperationPtr & grad,const py::object & cell,const py::object & weights,const py::object & grad_position,const py::args & args) const267 py::object PyNativeExecutor::RunGrad(const prim::GradOperationPtr &grad, const py::object &cell,
268 const py::object &weights, const py::object &grad_position,
269 const py::args &args) const {
270 runtime::ProfilerStageRecorder recorder(runtime::ProfilerStage::kRunGrad);
271 return PyNativeExecutorTry(grad_executor()->Run, grad, cell, weights, grad_position, args);
272 }
273
// Delegates to the GradJit of the grad executor's jit component and returns its result.
// The call is returned directly: binding it to a const reference first forced an extra py::object copy
// (a reference-count increment/decrement pair) on return.
py::object PyNativeExecutor::GradJit(const py::object &out, const py::args &args) const {
  return grad_executor()->jit()->GradJit(out, args);
}
278
IsFirstCell() const279 bool PyNativeExecutor::IsFirstCell() const { return forward_executor()->IsFirstCell(); }
280
WorkerJoin()281 void PyNativeExecutor::WorkerJoin() {
282 GilReleaseWithCheck release_gil;
283 runtime::Pipeline::Get().frontend_stage()->WorkerJoin();
284 }
285
SetJitCompileStatus(bool is_compiling,const std::string & phase) const286 void PyNativeExecutor::SetJitCompileStatus(bool is_compiling, const std::string &phase) const {
287 forward_executor()->set_is_jit_compiling(is_compiling);
288 grad_executor()->jit()->set_graph_phase(phase);
289 }
290
SetIsRunRecompute(bool is_runing_recompute) const291 void PyNativeExecutor::SetIsRunRecompute(bool is_runing_recompute) const {
292 grad_executor()->set_is_run_recompute(is_runing_recompute);
293 }
294
SetDynamicInput(const py::object & obj,const py::args & args) const295 void PyNativeExecutor::SetDynamicInput(const py::object &obj, const py::args &args) const {
296 grad_executor()->SaveDynamicInputsCells(obj, args);
297 if (grad_executor()->dynamic_shape()->enable_unknown_shape()) {
298 grad_executor()->dynamic_shape()->SetDynamicInput(obj, args);
299 }
300 }
301
// Returns the dynamic input for `actual_input` from the dynamic shape component when unknown shapes are
// enabled there; otherwise returns `actual_input` unchanged.
py::object PyNativeExecutor::GetDynamicInput(const py::object &actual_input) const {
  const auto &dynamic_shape = grad_executor()->dynamic_shape();
  if (!dynamic_shape->enable_unknown_shape()) {
    return actual_input;
  }
  MS_LOG(DEBUG) << "Get dynamic shape for jit";
  return dynamic_shape->GetDynamicInput(actual_input);
}
309
ParentBeforeFork()310 void PyNativeExecutor::ParentBeforeFork() {
311 MS_LOG(DEBUG) << "PyNativeExecutor prepare before fork.";
312 MS_LOG(DEBUG) << "Wait for OpExecutor.";
313 runtime::OpExecutor::GetInstance().WaitAll();
314 MS_LOG(DEBUG) << "Wait for grad_executor_.";
315 grad_executor_->bprop_queue()->Wait();
316 MS_LOG(DEBUG) << "PyNativeExecutor prepare before fork done.";
317 }
318
ChildAfterFork()319 void PyNativeExecutor::ChildAfterFork() {
320 MS_LOG(DEBUG) << "PyNativeExecutor reinitialize after fork.";
321 MS_LOG(DEBUG) << "Clear OpCompiler Cache.";
322 pynative::OpCompiler::GetInstance().ClearAllCache();
323 if (forward_executor_ != nullptr) {
324 MS_LOG(DEBUG) << "Clear forward_executor_ resources.";
325 forward_executor_->ClearRes();
326 // Call ForwardExecutor::ReInit() to update device_target_
327 forward_executor_->ReInit();
328 MS_LOG(DEBUG) << "Reinitialize forward_executor_.";
329 forward_executor_->ChildAfterFork();
330 }
331 // Reset PyNativeExecutor resources
332 if (grad_executor_ != nullptr) {
333 MS_LOG(DEBUG) << "Clear grad_executor_ resources.";
334 grad_executor_->ClearRes();
335 MS_LOG(DEBUG) << "Reinitialize grad_executor_.";
336 grad_executor_->ChildAfterFork();
337 }
338 runtime::OpRunner::ChildAfterFork();
339 MS_LOG(DEBUG) << "PyNativeExecutor reinitialize after fork done.";
340 }
341
SetAsyncForGraph(bool flag) const342 void PyNativeExecutor::SetAsyncForGraph(bool flag) const {
343 runtime::OpExecutor::GetInstance().set_async_for_graph(flag);
344 }
345
RegPyNativeExecutor(const py::module * m)346 void RegPyNativeExecutor(const py::module *m) {
347 stub::RegStubNodes(m);
348
349 (void)py::class_<PyNativeExecutor, std::shared_ptr<PyNativeExecutor>>(*m, "PyNativeExecutor_")
350 .def_static("get_instance", &PyNativeExecutor::GetInstance, "PyNativeExecutor get_instance.")
351 .def("is_first_cell", &PyNativeExecutor::IsFirstCell, "check if the first cell.")
352 .def("new_graph", &PyNativeExecutor::NewGraph, "pynative new a graph.")
353 .def("end_graph", &PyNativeExecutor::EndGraph, "pynative end a graph.")
354 .def("check_run", &PyNativeExecutor::CheckAlreadyRun, "pynative check graph run before.")
355 .def("grad_jit", &PyNativeExecutor::GradJit, "pynative grad for jit.")
356 .def("clear_res", &PyNativeExecutor::ClearRes, "pynative clear exception res.")
357 .def("sync", &PyNativeExecutor::Sync, "pynative sync stream.")
358 .def("grad", &PyNativeExecutor::RunGrad, "pynative executor run grad.")
359 .def("grad_flag", &PyNativeExecutor::grad_flag, "pynative grad flag")
360 .def("enable_grad", &PyNativeExecutor::enable_grad, "pynative enable grad, used for with no_grad")
361 .def("set_hook_changed", &PyNativeExecutor::SetHookChanged, "set pynative hook changed")
362 .def("set_grad_flag", &PyNativeExecutor::set_grad_flag, py::arg("flag") = py::bool_(false),
363 "Executor set grad flag.")
364 .def("set_enable_grad", &PyNativeExecutor::set_enable_grad, py::arg("enable_grad") = py::bool_(true),
365 "pynative set enable grad")
366 .def("set_dynamic_input", &PyNativeExecutor::SetDynamicInput, "set dynamic input")
367 .def("get_dynamic_input", &PyNativeExecutor::GetDynamicInput, "get dynamic input")
368 .def("set_py_exe_path", &PyNativeExecutor::set_py_exe_path, py::arg("py_exe_path") = py::str(""),
369 "set python executable path.")
370 .def("set_kernel_build_server_dir", &PyNativeExecutor::set_kernel_build_server_dir,
371 py::arg("kernel_build_server_dir") = py::str(""), "set kernel build server directory path.")
372 .def("set_jit_compile_status", &PyNativeExecutor::SetJitCompileStatus, "set jit compile status.")
373 .def("set_is_run_recompute", &PyNativeExecutor::SetIsRunRecompute, "set grad is in recompile status.")
374 .def("run_op_async", &PyNativeExecutor::RunOpStub, "run op asynchronously")
375 .def("set_async_for_graph", &PyNativeExecutor::SetAsyncForGraph, py::arg("flag") = py::bool_(false),
376 "Executor set async flag.")
377 .def("constant_folding", &PyNativeExecutor::CallConstantFolding, "Call Constant Folding Primitive");
378 }
379 } // namespace mindspore::pynative
380