• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <cerrno>
17 #include <cstring>
18 #include <csignal>
19 #ifdef __linux__
20 #include <sys/prctl.h>
21 #endif
22 #include "mindrt/include/fork_utils.h"
23 #include "utils/log_adapter.h"
24 #include "utils/ms_context.h"
25 #include "mindrt/src/actor/actormgr.h"
26 #include "runtime/pynative/op_executor.h"
27 #include "pipeline/pynative/pynative_execute.h"
28 #include "pipeline/jit/ps/pipeline.h"
29 #include "include/common/thread_pool.h"
30 #include "include/common/pybind_api/api_register.h"
31 #include "runtime/hardware/device_context_manager.h"
32 
33 namespace mindspore {
34 
RegisterForkCallbacks()35 void RegisterForkCallbacks() {
36 #if !defined(_WIN32)
37   MS_LOG(DEBUG) << "Register MsContext fork callbacks.";
38   ForkUtils::GetInstance().RegisterCallbacks(MsContext::GetInstance(), static_cast<void (MsContext::*)()>(nullptr),
39                                              static_cast<void (MsContext::*)()>(nullptr), &MsContext::ChildAfterFork);
40   MS_LOG(DEBUG) << "Register ActorMgr fork callbacks.";
41   ForkUtils::GetInstance().RegisterCallbacks(ActorMgr::GetActorMgrRef(), static_cast<void (ActorMgr::*)()>(nullptr),
42                                              static_cast<void (ActorMgr::*)()>(nullptr), &ActorMgr::ChildAfterFork);
43   MS_LOG(DEBUG) << "Register Common ThreadPool fork callbacks.";
44   ForkUtils::GetInstance().RegisterCallbacks(
45     &common::ThreadPool::GetInstance(), static_cast<void (common::ThreadPool::*)()>(nullptr),
46     static_cast<void (common::ThreadPool::*)()>(nullptr), &common::ThreadPool::ChildAfterFork);
47   MS_LOG(DEBUG) << "Register PyNativeExecutor fork callbacks.";
48   ForkUtils::GetInstance().RegisterCallbacks(
49     pynative::PyNativeExecutor::GetInstance(), &pynative::PyNativeExecutor::ParentBeforeFork,
50     static_cast<void (pynative::PyNativeExecutor::*)()>(nullptr), &pynative::PyNativeExecutor::ChildAfterFork);
51   MS_LOG(DEBUG) << "Register GraphExecutorPy fork callbacks.";
52   ForkUtils::GetInstance().RegisterCallbacks(
53     pipeline::GraphExecutorPy::GetInstance(), &pipeline::GraphExecutorPy::ParentBeforeFork,
54     &pipeline::GraphExecutorPy::ParentAfterFork, &pipeline::GraphExecutorPy::ChildAfterFork);
55   MS_LOG(DEBUG) << "Register OpExecutor fork callbacks.";
56   ForkUtils::GetInstance().RegisterCallbacks(
57     &runtime::OpExecutor::GetInstance(), static_cast<void (runtime::OpExecutor::*)()>(nullptr),
58     static_cast<void (runtime::OpExecutor::*)()>(nullptr), &runtime::OpExecutor::ChildAfterFork);
59   MS_LOG(DEBUG) << "Register DeviceContextManager fork callbacks.";
60   ForkUtils::GetInstance().RegisterCallbacks(
61     &device::DeviceContextManager::GetInstance(), static_cast<void (device::DeviceContextManager::*)()>(nullptr),
62     static_cast<void (device::DeviceContextManager::*)()>(nullptr), &device::DeviceContextManager::ChildAfterFork);
63   MS_LOG(DEBUG) << "Register GraphScheduler fork callbacks.";
64   ForkUtils::GetInstance().RegisterCallbacks(
65     &runtime::GraphScheduler::GetInstance(), static_cast<void (runtime::GraphScheduler::*)()>(nullptr),
66     static_cast<void (runtime::GraphScheduler::*)()>(nullptr), &runtime::GraphScheduler::ChildAfterFork);
67 #endif
68 }
69 
PrepareBeforeFork()70 void PrepareBeforeFork() {
71   MS_LOG(DEBUG) << "Parent process before fork.";
72 
73   // Register fork callbacks when first fork event occurs.
74   static std::once_flag once_flag_;
75   std::call_once(once_flag_, [&]() {
76     MS_LOG(DEBUG) << "Register fork event callbacks.";
77     RegisterForkCallbacks();
78   });
79 
80   // Trigger ParentBeforeFork callbacks in parent process.
81   ForkUtils::GetInstance().BeforeFork();
82 
83   // If the forked thread does not hold the gil lock, we need to manually acquire the gil lock before forking,
84   // otherwise the child process will block when acquiring the gil lock.
85   ForkUtils::GetInstance().SetGilHoldBeforeFork(PyGILState_Check());
86   if (!ForkUtils::GetInstance().IsGilHoldBeforeFork()) {
87     MS_LOG(DEBUG) << "Acquire GIL lock in parent process before fork.";
88     ForkUtils::GetInstance().SetGilState(static_cast<int>(PyGILState_Ensure()));
89   }
90 }
91 
ParentAtFork()92 void ParentAtFork() {
93   MS_LOG(DEBUG) << "Parent process at fork.";
94 
95   // Release the gil lock that was acquired manually before forking.
96   if (!ForkUtils::GetInstance().IsGilHoldBeforeFork()) {
97     MS_LOG(DEBUG) << "Release GIL lock acquired manually before fork.";
98     PyGILState_Release(static_cast<PyGILState_STATE>(ForkUtils::GetInstance().GetGilState()));
99   }
100 
101   // Trigger ParentAfterFork callbacks in parent process.
102   ForkUtils::GetInstance().ParentAtFork();
103 }
104 
ChildAtFork()105 void ChildAtFork() {
106   MS_LOG(DEBUG) << "Child process at fork.";
107 
108   // Release the gil lock that was acquired manually before forking.
109   if (!ForkUtils::GetInstance().IsGilHoldBeforeFork()) {
110     MS_LOG(DEBUG) << "Release GIL lock acquired manually before fork.";
111     PyGILState_Release(static_cast<PyGILState_STATE>(ForkUtils::GetInstance().GetGilState()));
112   }
113 
114   // Trigger ChildAfterFork callbacks in child process.
115   ForkUtils::GetInstance().ChildAtFork();
116 }
117 
SetPDeathSig(int signal)118 void SetPDeathSig(int signal) {
119 #ifdef __linux__
120   // prctl(2) is a Linux specific system call.
121   // On other systems the following function call has no effect.
122   // This is set to ensure that non-daemonic child processes can
123   // terminate if their parent terminates before they do.
124   MS_LOG(DEBUG) << "Set prctl PR_SET_PDEATHSIG: " << signal;
125   auto res = prctl(PR_SET_PDEATHSIG, signal);
126   if (res < 0) {
127     MS_LOG(WARNING) << "Set prctl PR_SET_PDEATHSIG failed:(" << errno << ")" << strerror(errno);
128   }
129 #endif
130 }
131 
RegForkUtils(py::module * m)132 void RegForkUtils(py::module *m) {
133   auto m_sub = m->def_submodule("fork_utils", "submodule for fork");
134   (void)m_sub.def("prepare_before_fork", &PrepareBeforeFork, "Callback function called in parent process before fork");
135   (void)m_sub.def("parent_at_fork", &ParentAtFork, "Callback function called in parent process after fork");
136   (void)m_sub.def("child_at_fork", &ChildAtFork, "Callback function called in child process after fork");
137   (void)m_sub.def("prctl_set_pdeathsig", &SetPDeathSig, py::arg("signal"),
138                   "Set signal to child process after parent process is dead");
139 }
140 }  // namespace mindspore
141