1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <cerrno>
17 #include <cstring>
18 #include <csignal>
19 #ifdef __linux__
20 #include <sys/prctl.h>
21 #endif
22 #include "mindrt/include/fork_utils.h"
23 #include "utils/log_adapter.h"
24 #include "utils/ms_context.h"
25 #include "mindrt/src/actor/actormgr.h"
26 #include "runtime/pynative/op_executor.h"
27 #include "pipeline/pynative/pynative_execute.h"
28 #include "pipeline/jit/ps/pipeline.h"
29 #include "include/common/thread_pool.h"
30 #include "include/common/pybind_api/api_register.h"
31 #include "runtime/hardware/device_context_manager.h"
32
33 namespace mindspore {
34
RegisterForkCallbacks()35 void RegisterForkCallbacks() {
36 #if !defined(_WIN32)
37 MS_LOG(DEBUG) << "Register MsContext fork callbacks.";
38 ForkUtils::GetInstance().RegisterCallbacks(MsContext::GetInstance(), static_cast<void (MsContext::*)()>(nullptr),
39 static_cast<void (MsContext::*)()>(nullptr), &MsContext::ChildAfterFork);
40 MS_LOG(DEBUG) << "Register ActorMgr fork callbacks.";
41 ForkUtils::GetInstance().RegisterCallbacks(ActorMgr::GetActorMgrRef(), static_cast<void (ActorMgr::*)()>(nullptr),
42 static_cast<void (ActorMgr::*)()>(nullptr), &ActorMgr::ChildAfterFork);
43 MS_LOG(DEBUG) << "Register Common ThreadPool fork callbacks.";
44 ForkUtils::GetInstance().RegisterCallbacks(
45 &common::ThreadPool::GetInstance(), static_cast<void (common::ThreadPool::*)()>(nullptr),
46 static_cast<void (common::ThreadPool::*)()>(nullptr), &common::ThreadPool::ChildAfterFork);
47 MS_LOG(DEBUG) << "Register PyNativeExecutor fork callbacks.";
48 ForkUtils::GetInstance().RegisterCallbacks(
49 pynative::PyNativeExecutor::GetInstance(), &pynative::PyNativeExecutor::ParentBeforeFork,
50 static_cast<void (pynative::PyNativeExecutor::*)()>(nullptr), &pynative::PyNativeExecutor::ChildAfterFork);
51 MS_LOG(DEBUG) << "Register GraphExecutorPy fork callbacks.";
52 ForkUtils::GetInstance().RegisterCallbacks(
53 pipeline::GraphExecutorPy::GetInstance(), &pipeline::GraphExecutorPy::ParentBeforeFork,
54 &pipeline::GraphExecutorPy::ParentAfterFork, &pipeline::GraphExecutorPy::ChildAfterFork);
55 MS_LOG(DEBUG) << "Register OpExecutor fork callbacks.";
56 ForkUtils::GetInstance().RegisterCallbacks(
57 &runtime::OpExecutor::GetInstance(), static_cast<void (runtime::OpExecutor::*)()>(nullptr),
58 static_cast<void (runtime::OpExecutor::*)()>(nullptr), &runtime::OpExecutor::ChildAfterFork);
59 MS_LOG(DEBUG) << "Register DeviceContextManager fork callbacks.";
60 ForkUtils::GetInstance().RegisterCallbacks(
61 &device::DeviceContextManager::GetInstance(), static_cast<void (device::DeviceContextManager::*)()>(nullptr),
62 static_cast<void (device::DeviceContextManager::*)()>(nullptr), &device::DeviceContextManager::ChildAfterFork);
63 MS_LOG(DEBUG) << "Register GraphScheduler fork callbacks.";
64 ForkUtils::GetInstance().RegisterCallbacks(
65 &runtime::GraphScheduler::GetInstance(), static_cast<void (runtime::GraphScheduler::*)()>(nullptr),
66 static_cast<void (runtime::GraphScheduler::*)()>(nullptr), &runtime::GraphScheduler::ChildAfterFork);
67 #endif
68 }
69
PrepareBeforeFork()70 void PrepareBeforeFork() {
71 MS_LOG(DEBUG) << "Parent process before fork.";
72
73 // Register fork callbacks when first fork event occurs.
74 static std::once_flag once_flag_;
75 std::call_once(once_flag_, [&]() {
76 MS_LOG(DEBUG) << "Register fork event callbacks.";
77 RegisterForkCallbacks();
78 });
79
80 // Trigger ParentBeforeFork callbacks in parent process.
81 ForkUtils::GetInstance().BeforeFork();
82
83 // If the forked thread does not hold the gil lock, we need to manually acquire the gil lock before forking,
84 // otherwise the child process will block when acquiring the gil lock.
85 ForkUtils::GetInstance().SetGilHoldBeforeFork(PyGILState_Check());
86 if (!ForkUtils::GetInstance().IsGilHoldBeforeFork()) {
87 MS_LOG(DEBUG) << "Acquire GIL lock in parent process before fork.";
88 ForkUtils::GetInstance().SetGilState(static_cast<int>(PyGILState_Ensure()));
89 }
90 }
91
ParentAtFork()92 void ParentAtFork() {
93 MS_LOG(DEBUG) << "Parent process at fork.";
94
95 // Release the gil lock that was acquired manually before forking.
96 if (!ForkUtils::GetInstance().IsGilHoldBeforeFork()) {
97 MS_LOG(DEBUG) << "Release GIL lock acquired manually before fork.";
98 PyGILState_Release(static_cast<PyGILState_STATE>(ForkUtils::GetInstance().GetGilState()));
99 }
100
101 // Trigger ParentAfterFork callbacks in parent process.
102 ForkUtils::GetInstance().ParentAtFork();
103 }
104
ChildAtFork()105 void ChildAtFork() {
106 MS_LOG(DEBUG) << "Child process at fork.";
107
108 // Release the gil lock that was acquired manually before forking.
109 if (!ForkUtils::GetInstance().IsGilHoldBeforeFork()) {
110 MS_LOG(DEBUG) << "Release GIL lock acquired manually before fork.";
111 PyGILState_Release(static_cast<PyGILState_STATE>(ForkUtils::GetInstance().GetGilState()));
112 }
113
114 // Trigger ChildAfterFork callbacks in child process.
115 ForkUtils::GetInstance().ChildAtFork();
116 }
117
SetPDeathSig(int signal)118 void SetPDeathSig(int signal) {
119 #ifdef __linux__
120 // prctl(2) is a Linux specific system call.
121 // On other systems the following function call has no effect.
122 // This is set to ensure that non-daemonic child processes can
123 // terminate if their parent terminates before they do.
124 MS_LOG(DEBUG) << "Set prctl PR_SET_PDEATHSIG: " << signal;
125 auto res = prctl(PR_SET_PDEATHSIG, signal);
126 if (res < 0) {
127 MS_LOG(WARNING) << "Set prctl PR_SET_PDEATHSIG failed:(" << errno << ")" << strerror(errno);
128 }
129 #endif
130 }
131
RegForkUtils(py::module * m)132 void RegForkUtils(py::module *m) {
133 auto m_sub = m->def_submodule("fork_utils", "submodule for fork");
134 (void)m_sub.def("prepare_before_fork", &PrepareBeforeFork, "Callback function called in parent process before fork");
135 (void)m_sub.def("parent_at_fork", &ParentAtFork, "Callback function called in parent process after fork");
136 (void)m_sub.def("child_at_fork", &ChildAtFork, "Callback function called in child process after fork");
137 (void)m_sub.def("prctl_set_pdeathsig", &SetPDeathSig, py::arg("signal"),
138 "Set signal to child process after parent process is dead");
139 }
140 } // namespace mindspore
141