• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_
17 #define MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_
18 
#if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
#include <semaphore.h>
#include <signal.h>  // for sig_atomic_t
#endif
#include <condition_variable>
#include <csignal>  // sig_atomic_t is used on every platform, not only POSIX builds
#include <functional>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include "minddata/dataset/util/allocator.h"
#include "minddata/dataset/util/intrp_service.h"
#include "minddata/dataset/util/lock.h"
#include "minddata/dataset/util/services.h"
#include "minddata/dataset/util/status.h"
#include "minddata/dataset/util/task.h"
34 
35 namespace mindspore {
36 namespace dataset {
namespace thread {
/// Thread identifier type used throughout the dataset utilities
/// (an alias of std::thread::id).
using id = ::std::thread::id;
}  // namespace thread

namespace this_thread {
/// Returns the identifier of the calling thread.
inline thread::id get_id() {
  const thread::id self = ::std::this_thread::get_id();
  return self;
}
}  // namespace this_thread
44 
45 class TaskManager : public Service {
46  public:
47   friend class Services;
48 
49   friend class TaskGroup;
50 
51   ~TaskManager() override;
52 
53   TaskManager(const TaskManager &) = delete;
54 
55   TaskManager &operator=(const TaskManager &) = delete;
56 
CreateInstance()57   static Status CreateInstance() {
58     std::call_once(init_instance_flag_, [&]() -> Status {
59       auto &svcManager = Services::GetInstance();
60       RETURN_IF_NOT_OK(svcManager.AddHook(&instance_));
61       return Status::OK();
62     });
63     return Status::OK();
64   }
65 
GetInstance()66   static TaskManager &GetInstance() noexcept { return *instance_; }
67 
68   Status DoServiceStart() override;
69 
70   Status DoServiceStop() override;
71 
72   // A public global interrupt flag for signal handlers
73   volatile sig_atomic_t global_interrupt_;
74 
75   // API
76   // This takes the same parameter as Task constructor. Take a look
77   // of the test-thread.cc for usage.
78   Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, TaskGroup *vg, Task **,
79                          int32_t operator_id = -1);
80 
81   // Same usage as boot thread group
82   Status join_all();
83 
84   void interrupt_all() noexcept;
85 
86   // Locate a particular Task.
87   static Task *FindMe();
88 
89   static void InterruptGroup(Task &);
90 
91   static Status GetMasterThreadRc();
92 
93   static void InterruptMaster(const Status &rc = Status::OK());
94 
WakeUpWatchDog()95   static void WakeUpWatchDog() {
96 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
97     TaskManager &tm = TaskManager::GetInstance();
98     (void)sem_post(&tm.sem_);
99 #endif
100   }
101 
102   void ReturnFreeTask(Task *p) noexcept;
103 
104   Status GetFreeTask(const std::string &my_name, const std::function<Status()> &f, Task **p, int32_t operator_id = -1);
105 
106   Status WatchDog();
107 
108  private:
109   static std::once_flag init_instance_flag_;
110   static TaskManager *instance_;
111   RWLock lru_lock_;
112   SpinLock free_lock_;
113   SpinLock tg_lock_;
114   std::shared_ptr<Task> master_;
115   List<Task> lru_;
116   List<Task> free_lst_;
117 #if !defined(_WIN32) && !defined(_WIN64) && !defined(__ANDROID__) && !defined(ANDROID) && !defined(__APPLE__)
118   sem_t sem_;
119 #endif
120   TaskGroup *watchdog_grp_;
121   std::set<TaskGroup *> grp_list_;
122   Task *watchdog_;
123 
124   TaskManager();
125 };
126 
127 // A group of related tasks.
128 class TaskGroup : public Service {
129  public:
130   friend class Task;
131   friend class TaskManager;
132 
133   Status CreateAsyncTask(const std::string &my_name, const std::function<Status()> &f, Task **pTask = nullptr,
134                          int32_t operator_id = -1);
135 
136   void interrupt_all() noexcept;
137 
138   Status join_all(Task::WaitFlag wf = Task::WaitFlag::kBlocking);
139 
size()140   int size() const noexcept { return grp_list_.count; }
141 
GetTask()142   List<Task> GetTask() const noexcept { return grp_list_; }
143 
DoServiceStart()144   Status DoServiceStart() override { return Status::OK(); }
145 
146   Status DoServiceStop() override;
147 
148   TaskGroup();
149 
150   ~TaskGroup() override;
151 
152   Status GetTaskErrorIfAny();
153 
154   std::shared_ptr<IntrpService> GetIntrpService();
155 
156  private:
157   Status rc_;
158   // Can't use rw_lock_ as we will lead to deadlatch. Create another mutex to serialize access to rc_.
159   std::mutex rc_mux_;
160   RWLock rw_lock_;
161   List<Task> grp_list_;
162   std::shared_ptr<IntrpService> intrp_svc_;
163 };
164 
165 namespace this_thread {
is_interrupted()166 inline bool is_interrupted() {
167   TaskManager &tm = TaskManager::GetInstance();
168   if (tm.global_interrupt_ == 1) {
169     return true;
170   }
171   Task *my_task = TaskManager::FindMe();
172   return my_task->Interrupted();
173 }
174 
is_master_thread()175 inline bool is_master_thread() {
176   Task *my_task = TaskManager::FindMe();
177   return my_task->IsMasterThread();
178 }
179 
GetInterruptStatus()180 inline Status GetInterruptStatus() {
181   Task *my_task = TaskManager::FindMe();
182   return my_task->GetInterruptStatus();
183 }
184 }  // namespace this_thread
185 
// Returns early from the enclosing function when the calling thread has been
// interrupted. The returned value is Task::OverrideInterruptRc applied to the
// thread's interrupt status, so the enclosing function's return type must accept
// that result (presumably Status). Every name is fully qualified so the macro
// also works when expanded outside namespace mindspore::dataset.
#define RETURN_IF_INTERRUPTED()                                          \
  do {                                                                   \
    if (::mindspore::dataset::this_thread::is_interrupted()) {           \
      return ::mindspore::dataset::Task::OverrideInterruptRc(            \
        ::mindspore::dataset::this_thread::GetInterruptStatus());        \
    }                                                                    \
  } while (false)
192 
193 }  // namespace dataset
194 }  // namespace mindspore
195 
196 #endif  // MINDSPORE_CCSRC_MINDDATA_DATASET_UTIL_TASK_MANAGER_H_
197