• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_UTILS_FORK_UTILS_H
18 #define MINDSPORE_CORE_UTILS_FORK_UTILS_H
19 #include <cstdio>
20 #include <vector>
21 #include <string>
22 #include <functional>
23 #include <mutex>
24 #include <memory>
25 #include "utils/ms_utils.h"
26 
// Debug logging helper for the fork utilities.
//
// When FORK_UTILS_DEBUG is defined, the message is printed through std::printf, prefixed with the
// calling function name and line number. `content` is stringified with `#`, so when a string literal
// is passed its surrounding quotes are part of the printed text; printf conversions inside it are
// still honoured and consume the trailing arguments.
// When FORK_UTILS_DEBUG is not defined, the macro expands to nothing.
//
// Always terminate an invocation with a semicolon. The debug expansion is wrapped in
// `do { ... } while (0)` so that it forms a single statement and stays valid inside an unbraced
// if/else.
#ifdef FORK_UTILS_DEBUG
#define FORK_UTILS_LOG(content, ...)                                                       \
  do {                                                                                     \
    std::printf("[FORK_UTILS] %s|%d: " #content "\r\n", __func__, __LINE__, ##__VA_ARGS__); \
  } while (0)
#else
#define FORK_UTILS_LOG(content, ...)
#endif
33 
34 namespace mindspore {
// One registration record: the object the callbacks belong to plus one callable per fork stage.
// The struct is an aggregate and is brace-initialised in member order, so the order of the
// members must not change.
struct fork_callback_info {
  // Address of the registering object. It is only compared against other addresses to find or
  // remove a record; it is stored as a raw, non-owning pointer. Defaults to nullptr so that a
  // default-constructed record never carries an indeterminate pointer into those comparisons.
  void *class_obj = nullptr;
  // Callable for the "before fork" stage.
  std::function<void()> before_fork_func;
  // Callable for the "parent at fork" stage.
  std::function<void()> parent_atfork_func;
  // Callable for the "child at fork" stage.
  std::function<void()> child_atfork_func;
};
41 
42 MS_CORE_API void EmptyFunction();
43 
44 class MS_CORE_API ForkUtils {
45  public:
46   static ForkUtils &GetInstance() noexcept;
47 
48   template <class T>
RegisterCallbacks(std::shared_ptr<T> obj,void (T::* before_fork)(),void (T::* parent_atfork)(),void (T::* child_atfork)())49   void RegisterCallbacks(std::shared_ptr<T> obj, void (T::*before_fork)(), void (T::*parent_atfork)(),
50                          void (T::*child_atfork)()) {
51     RegisterCallbacks(obj.get(), before_fork, parent_atfork, child_atfork);
52   }
53 
54   template <class T>
RegisterCallbacks(T * obj,void (T::* before_fork)(),void (T::* parent_atfork)(),void (T::* child_atfork)())55   void RegisterCallbacks(T *obj, void (T::*before_fork)(), void (T::*parent_atfork)(), void (T::*child_atfork)()) {
56 #if !defined(_WIN32) && !defined(BUILD_LITE)
57     FORK_UTILS_LOG("Register fork callback info.");
58 
59     struct fork_callback_info callback_info = {obj, EmptyFunction, EmptyFunction, EmptyFunction};
60     if (before_fork) {
61       callback_info.before_fork_func = std::bind(before_fork, obj);
62     }
63     if (parent_atfork) {
64       callback_info.parent_atfork_func = std::bind(parent_atfork, obj);
65     }
66     if (child_atfork) {
67       callback_info.child_atfork_func = std::bind(child_atfork, obj);
68     }
69 
70     bool exist_ = false;
71     for (auto &iter : fork_callbacks_) {
72       FORK_UTILS_LOG("Callback_info already exist, update info.");
73       if (iter.class_obj == obj) {
74         exist_ = true;
75         iter = callback_info;
76         break;
77       }
78     }
79     if (exist_ == false) {
80       FORK_UTILS_LOG("Create new callback info.");
81       fork_callbacks_.push_back(callback_info);
82     }
83 #endif
84   }
85 
86   // Note: Do not call this deregistration interface in the destructor of a global object or a singleton object,
87   // because fork_utils object may be destructed before them.
88   template <class T>
DeregCallbacks(const T * obj)89   void DeregCallbacks(const T *obj) noexcept {
90 #if !defined(_WIN32) && !defined(BUILD_LITE)
91     try {
92       FORK_UTILS_LOG("Deregister fork callback info.");
93       for (auto iter = fork_callbacks_.begin(); iter != fork_callbacks_.end(); (void)++iter) {
94         if (iter->class_obj == obj) {
95           (void)fork_callbacks_.erase(iter);
96           break;
97         }
98       }
99     } catch (const std::exception &e) {
100       FORK_UTILS_LOG("Deregister fork callback info failed: %s", e.what());
101     } catch (...) {
102       FORK_UTILS_LOG("Deregister fork callback info failed: Unknown exception.");
103     }
104 #endif
105   }
106 
GetCallbacks()107   std::vector<fork_callback_info> GetCallbacks() { return fork_callbacks_; }
108 
SetGilHoldBeforeFork(bool gil_hold_before_fork)109   void SetGilHoldBeforeFork(bool gil_hold_before_fork) { hold_gil_before_fork_ = gil_hold_before_fork; }
110 
IsGilHoldBeforeFork()111   bool IsGilHoldBeforeFork() const { return hold_gil_before_fork_; }
112 
SetGilState(int gil_state)113   void SetGilState(int gil_state) { gil_state_ = gil_state; }
114 
GetGilState()115   int GetGilState() const { return gil_state_; }
116 
117   void BeforeFork();
118 
119   void ParentAtFork();
120 
121   void ChildAtFork();
122 
123  private:
124   ForkUtils() = default;
125   ~ForkUtils() = default;
126   std::vector<fork_callback_info> fork_callbacks_;
127   // Record whether forked thread holds the gil lock when the fork occurs.
128   bool hold_gil_before_fork_ = false;
129   int gil_state_ = 0;
130 };
131 }  // namespace mindspore
132 #endif  // MINDSPORE_CORE_UTILS_FORK_UTILS_H
133