1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_
18
#include <signal.h>  // for sig_atomic_t
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#include <semaphore.h>
#endif
#include <condition_variable>
#include <functional>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/intrp_service.h"
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task.h"
34
35 namespace mindspore {
36 namespace dataset {
namespace thread {
// Identifier of a thread; this is exactly the standard library's thread id type.
using id = std::thread::id;
}  // namespace thread

namespace this_thread {
// Returns the identifier of the thread that calls it.
inline thread::id get_id() {
  const thread::id self = std::this_thread::get_id();
  return self;
}
}  // namespace this_thread
44
45 class TaskManager : public Service {
46 public:
47 friend class Services;
48
49 friend class TaskGroup;
50
51 ~TaskManager() override;
52
53 TaskManager(const TaskManager &) = delete;
54
55 TaskManager &operator=(const TaskManager &) = delete;
56
CreateInstance()57 static Status CreateInstance() {
58 std::call_once(init_instance_flag_, [&]() -> Status {
59 auto &svcManager = Services::GetInstance();
60 RETURN_IF_NOT_OK(svcManager.AddHook(&instance_));
61 return Status::OK();
62 });
63 return Status::OK();
64 }
65
GetInstance()66 static TaskManager &GetInstance() noexcept { return *instance_; }
67
68 Status DoServiceStart() override;
69
70 Status DoServiceStop() override;
71
72 // A public global interrupt flag for signal handlers
73 volatile sig_atomic_t global_interrupt_;
74
75 // API
76 // This takes the same parameter as Task constructor. Take a look
77 // of the test-thread.cc for usage.
78 Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg, Task **,
79 int32_t operator_id = -1);
80
81 // Same usage as boot thread group
82 Status join_all();
83
84 void interrupt_all() noexcept;
85
86 // Locate a particular Task.
87 static Task *FindMe();
88
89 static void InterruptGroup(Task &);
90
91 static Status GetMasterThreadRc();
92
93 static void InterruptMaster(const Status &rc = Status::OK());
94
WakeUpWatchDog()95 static void WakeUpWatchDog() {
96 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
97 TaskManager &tm = TaskManager::GetInstance();
98 (void)sem_post(&tm.sem_);
99 #endif
100 }
101
102 void ReturnFreeTask(Task *p) noexcept;
103
104 Status GetFreeTask(const std::string &my_name, const std::function<Status()> &f, Task **p, int32_t operator_id = -1);
105
106 Status WatchDog();
107
108 private:
109 static std::once_flag init_instance_flag_;
110 static TaskManager *instance_;
111 RWLock lru_lock_;
112 SpinLock free_lock_;
113 SpinLock tg_lock_;
114 std::shared_ptr<Task> master_;
115 List<Task> lru_;
116 List<Task> free_lst_;
117 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
118 sem_t sem_;
119 #endif
120 TaskGroup *watchdog_grp_;
121 std::set<TaskGroup *> grp_list_;
122 Task *watchdog_;
123
124 TaskManager();
125 };
126
127 // A group of related tasks.
128 class TaskGroup : public Service {
129 public:
130 friend class Task;
131 friend class TaskManager;
132
133 Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, Task **pTask = nullptr,
134 int32_t operator_id = -1);
135
136 void interrupt_all() noexcept;
137
138 Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking);
139
size()140 int size() const noexcept { return grp_list_.count; }
141
GetTask()142 List<Task> GetTask() const noexcept { return grp_list_; }
143
DoServiceStart()144 Status DoServiceStart() override { return Status::OK(); }
145
146 Status DoServiceStop() override;
147
148 TaskGroup();
149
150 ~TaskGroup() override;
151
152 Status GetTaskErrorIfAny();
153
154 std::shared_ptr<IntrpService> GetIntrpService();
155
156 private:
157 Status rc_;
158 // Can't use rw_lock_ as we will lead to deadlatch. Create another mutex to serialize access to rc_.
159 std::mutex rc_mux_;
160 RWLock rw_lock_;
161 List<Task> grp_list_;
162 std::shared_ptr<IntrpService> intrp_svc_;
163 };
164
165 namespace this_thread {
is_interrupted()166 inline bool is_interrupted() {
167 TaskManager &tm = TaskManager::GetInstance();
168 if (tm.global_interrupt_ == 1) {
169 return true;
170 }
171 Task *my_task = TaskManager::FindMe();
172 return my_task->Interrupted();
173 }
174
is_master_thread()175 inline bool is_master_thread() {
176 Task *my_task = TaskManager::FindMe();
177 return my_task->IsMasterThread();
178 }
179
GetInterruptStatus()180 inline Status GetInterruptStatus() {
181 Task *my_task = TaskManager::FindMe();
182 return my_task->GetInterruptStatus();
183 }
184 } // namespace this_thread
185
// Early return for Status-returning code running inside a Task: when the calling thread is interrupted,
// return its interrupt status after passing it through Task::OverrideInterruptRc.
// Every name is fully qualified so the macro expands correctly from any namespace, including scopes where
// std::this_thread is visible through a using-directive.
#define RETURN_IF_INTERRUPTED()                                                                               \
  do {                                                                                                        \
    if (mindspore::dataset::this_thread::is_interrupted()) {                                                  \
      return mindspore::dataset::Task::OverrideInterruptRc(                                                   \
        mindspore::dataset::this_thread::GetInterruptStatus());                                               \
    }                                                                                                         \
  } while (false)
192
193 } // namespace dataset
194 } // namespace mindspore
195
196 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_
197