1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_UTILS_FORK_UTILS_H 18 #define MINDSPORE_CORE_UTILS_FORK_UTILS_H 19 #include <cstdio> 20 #include <vector> 21 #include <string> 22 #include <functional> 23 #include <mutex> 24 #include <memory> 25 #include "utils/ms_utils.h" 26 27 #ifdef FORK_UTILS_DEBUG 28 #define FORK_UTILS_LOG(content, args...) \ 29 { std::printf("[FORK_UTILS] %s|%d: " #content "\r\n", __func__, __LINE__, ##args); } 30 #else 31 #define FORK_UTILS_LOG(content, ...) 32 #endif 33 34 namespace mindspore { 35 struct fork_callback_info { 36 void *class_obj; 37 std::function<void()> before_fork_func; 38 std::function<void()> parent_atfork_func; 39 std::function<void()> child_atfork_func; 40 }; 41 42 MS_CORE_API void EmptyFunction(); 43 44 class MS_CORE_API ForkUtils { 45 public: 46 static ForkUtils &GetInstance() noexcept; 47 48 template <class T> RegisterCallbacks(std::shared_ptr<T> obj,void (T::* before_fork)(),void (T::* parent_atfork)(),void (T::* child_atfork)())49 void RegisterCallbacks(std::shared_ptr<T> obj, void (T::*before_fork)(), void (T::*parent_atfork)(), 50 void (T::*child_atfork)()) { 51 RegisterCallbacks(obj.get(), before_fork, parent_atfork, child_atfork); 52 } 53 54 template <class T> RegisterCallbacks(T * obj,void (T::* before_fork)(),void (T::* parent_atfork)(),void (T::* child_atfork)())55 void RegisterCallbacks(T *obj, void (T::*before_fork)(), void (T::*parent_atfork)(), void (T::*child_atfork)()) { 56 #if !defined(_WIN32) && !defined(BUILD_LITE) 57 FORK_UTILS_LOG("Register fork callback info."); 58 59 struct fork_callback_info callback_info = {obj, EmptyFunction, EmptyFunction, EmptyFunction}; 60 if (before_fork) { 61 callback_info.before_fork_func = std::bind(before_fork, obj); 62 } 63 if (parent_atfork) { 64 callback_info.parent_atfork_func = std::bind(parent_atfork, obj); 65 } 66 if (child_atfork) { 67 callback_info.child_atfork_func = std::bind(child_atfork, obj); 68 } 69 70 bool exist_ = false; 71 for (auto &iter : fork_callbacks_) { 72 FORK_UTILS_LOG("Callback_info already exist, update info."); 73 if (iter.class_obj == obj) { 74 exist_ = true; 75 iter = callback_info; 76 break; 77 } 78 } 79 if (exist_ == false) { 80 FORK_UTILS_LOG("Create new callback info."); 81 fork_callbacks_.push_back(callback_info); 82 } 83 #endif 84 } 85 86 // Note: Do not call this deregistration interface in the destructor of a global object or a singleton object, 87 // because fork_utils object may be destructed before them. 88 template <class T> DeregCallbacks(const T * obj)89 void DeregCallbacks(const T *obj) noexcept { 90 #if !defined(_WIN32) && !defined(BUILD_LITE) 91 try { 92 FORK_UTILS_LOG("Deregister fork callback info."); 93 for (auto iter = fork_callbacks_.begin(); iter != fork_callbacks_.end(); (void)++iter) { 94 if (iter->class_obj == obj) { 95 (void)fork_callbacks_.erase(iter); 96 break; 97 } 98 } 99 } catch (const std::exception &e) { 100 FORK_UTILS_LOG("Deregister fork callback info failed: %s", e.what()); 101 } catch (...) { 102 FORK_UTILS_LOG("Deregister fork callback info failed: Unknown exception."); 103 } 104 #endif 105 } 106 GetCallbacks()107 std::vector<fork_callback_info> GetCallbacks() { return fork_callbacks_; } 108 SetGilHoldBeforeFork(bool gil_hold_before_fork)109 void SetGilHoldBeforeFork(bool gil_hold_before_fork) { hold_gil_before_fork_ = gil_hold_before_fork; } 110 IsGilHoldBeforeFork()111 bool IsGilHoldBeforeFork() const { return hold_gil_before_fork_; } 112 SetGilState(int gil_state)113 void SetGilState(int gil_state) { gil_state_ = gil_state; } 114 GetGilState()115 int GetGilState() const { return gil_state_; } 116 117 void BeforeFork(); 118 119 void ParentAtFork(); 120 121 void ChildAtFork(); 122 123 private: 124 ForkUtils() = default; 125 ~ForkUtils() = default; 126 std::vector<fork_callback_info> fork_callbacks_; 127 // Record whether forked thread holds the gil lock when the fork occurs. 128 bool hold_gil_before_fork_ = false; 129 int gil_state_ = 0; 130 }; 131 } // namespace mindspore 132 #endif // MINDSPORE_CORE_UTILS_FORK_UTILS_H 133