/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/common/thread_pool.h"
#include <exception>
#include "thread/threadlog.h"
#include "utils/log_adapter.h"
#include "utils/ms_exception.h"

namespace mindspore {
namespace common {
constexpr size_t kYieldThreshold = 1000;

ThreadPool::ThreadPool() : max_thread_num_(std::thread::hardware_concurrency()) {}

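// Worker loop: each thread spins for up to kYieldThreshold iterations waiting
// for a task (yielding the CPU between polls), then falls back to blocking on
// its condition variable until a task arrives or the pool is shutting down.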
void ThreadPool::SyncRunLoop(const std::shared_ptr<ThreadContext> &context) {
  if (context == nullptr) {
    return;
  }
  size_t yield_count = 0;
  while (true) {
    if (exit_run_) {
      return;
    }

    if (!context->task) {
      MS_EXCEPTION_IF_NULL(context->cond_var);
      ++yield_count;
      if (yield_count > kYieldThreshold) {
        yield_count = 0;
        std::unique_lock<std::mutex> lock(context->mutex);
        context->cond_var->wait(lock, [&context, this] { return context->task != nullptr || exit_run_; });
      } else {
        std::this_thread::yield();
        continue;
      }
    }

    if (exit_run_) {
      return;
    }

    try {
      auto &task = *(context->task);
      task();
    } catch (const std::exception &) {
      // Record the current exception so the dispatching thread can surface it.
      MsException::Instance().SetException();
    }
    yield_count = 0;
    context->task = nullptr;
  }
}
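// SetAffinity binds one worker thread to a CPU set. The Windows build is a
// stub that reports failure; with BIND_CORE defined, Android (API >= 21)
// binds via sched_setaffinity() on the thread's tid, Apple platforms are
// unsupported, and other POSIX platforms use pthread_setaffinity_np().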
#ifdef _WIN32
bool ThreadPool::SetAffinity() const { return false; }
#elif defined(BIND_CORE)
bool ThreadPool::SetAffinity(const pthread_t &thread_id, cpu_set_t *cpu_set) {
  if (cpu_set == nullptr) {
    return false;
  }
#ifdef __ANDROID__
#if __ANDROID_API__ >= 21
  THREAD_INFO("thread: %d, mask: %lu", pthread_gettid_np(thread_id), cpu_set->__bits[0]);
  int ret = sched_setaffinity(pthread_gettid_np(thread_id), sizeof(cpu_set_t), cpu_set);
  if (ret != THREAD_OK) {
    THREAD_ERROR("bind thread %d to cpu failed. ERROR %d", pthread_gettid_np(thread_id), ret);
    return false;
  }
  return true;
#endif  // __ANDROID_API__ >= 21
#else
#if defined(__APPLE__)
  THREAD_ERROR("thread affinity is not supported on Apple platforms.");
  return false;
#else
  int ret = pthread_setaffinity_np(thread_id, sizeof(cpu_set_t), cpu_set);
  if (ret != THREAD_OK) {
    THREAD_ERROR("set thread: %lu to cpu failed", thread_id);
    return false;
  }
  return true;
#endif  // __APPLE__
#endif  // __ANDROID__
  return false;
}
#endif  // BIND_CORE

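// After a batch completes, FreeScheduleThreads re-binds every worker to the
// whole core_list, relaxing the one-core-per-thread pinning applied by
// SetCpuAffinity so the OS may schedule the idle workers freely.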
bool ThreadPool::FreeScheduleThreads(const std::vector<int> &core_list) {
  if (core_list.empty()) {
    return false;
  }
#ifdef _WIN32
  return false;
#elif defined(BIND_CORE)
  for (const auto &sync_run_thread : sync_run_threads_) {
    cpu_set_t mask;
    CPU_ZERO(&mask);
    for (auto core_id : core_list) {
      CPU_SET(core_id, &mask);
    }
    if (!SetAffinity(sync_run_thread->native_handle(), &mask)) {
      return false;
    }
  }
  return true;
#endif  // BIND_CORE
  return false;
}

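// SetCpuAffinity pins the workers round-robin across core_list: worker i is
// bound to core_list[i % core_list.size()].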
bool ThreadPool::SetCpuAffinity(const std::vector<int> &core_list) {
  if (core_list.empty()) {
    return false;
  }
#ifdef _WIN32
  return false;
#elif defined(BIND_CORE)
  for (size_t i = 0; i < sync_run_threads_.size(); i++) {
    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(core_list[i % core_list.size()], &mask);
    if (!SetAffinity(sync_run_threads_[i]->native_handle(), &mask)) {
      return false;
    }
  }
  return true;
#endif  // BIND_CORE
  return false;
}

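// SyncRun executes a batch of tasks and blocks until all of them finish.
// A single task runs inline on the caller's thread; otherwise the pool is
// grown lazily up to min(task count, max_thread_num_) workers, the tasks are
// handed out to idle workers, and the caller spins until every slot drains.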
bool ThreadPool::SyncRun(const std::vector<Task> &tasks, const std::vector<int> &core_list) {
  if (tasks.empty()) {
    return true;
  }
  if (tasks.size() == 1) {
    // A single task is not worth the dispatch overhead; run it inline.
    auto ret = tasks[0]();
    return ret == SUCCESS;
  }
  std::unique_lock<std::mutex> lock(pool_mtx_);
  exit_run_ = false;
  size_t task_num = tasks.size();
  size_t thread_num = sync_run_threads_.size();
  if (thread_num < max_thread_num_ && thread_num < task_num) {
    // Lazily grow the pool, but never beyond max_thread_num_ or the task count.
    auto new_thread_num = max_thread_num_;
    if (task_num < max_thread_num_) {
      new_thread_num = task_num;
    }
    contexts_.resize(new_thread_num);
    for (size_t i = thread_num; i < new_thread_num; ++i) {
      contexts_[i] = std::make_shared<ThreadContext>();
      sync_run_threads_.emplace_back(std::make_unique<std::thread>(&ThreadPool::SyncRunLoop, this, contexts_[i]));
    }
  }
  if (contexts_.empty()) {
    return true;
  }
  auto set_affinity_ret = SetCpuAffinity(core_list);
  if (set_affinity_ret) {
    MS_LOG(INFO) << "Set cpu affinity success.";
  } else {
    MS_LOG(DEBUG) << "Set cpu affinity failed.";
  }
  size_t used_thread_num = contexts_.size();
  if (task_num < used_thread_num) {
    used_thread_num = task_num;
  }
  // Dispatch loop: hand each pending task to the first idle worker, then keep
  // spinning until every worker has drained its slot.
  bool running = true;
  size_t task_index = 0;
  while (running) {
    running = false;
    for (size_t i = 0; i < used_thread_num; ++i) {
      MS_EXCEPTION_IF_NULL(contexts_[i]);
      MS_EXCEPTION_IF_NULL(contexts_[i]->cond_var);
      auto &task_run = contexts_[i]->task;
      if (task_run) {
        // Worker i is still busy with a previously assigned task.
        running = true;
      } else if (task_index < task_num) {
        std::lock_guard<std::mutex> task_lock(contexts_[i]->mutex);
        contexts_[i]->task = &(tasks[task_index]);
        contexts_[i]->cond_var->notify_one();
        running = true;
        ++task_index;
      }
    }
    if (running) {
      std::this_thread::yield();
    }
  }
  auto free_schedule_threads_ret = FreeScheduleThreads(core_list);
  if (free_schedule_threads_ret) {
    MS_LOG(INFO) << "Free schedule threads success.";
  } else {
    MS_LOG(DEBUG) << "Free schedule threads failed.";
  }
  return true;
}

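// Meyers singleton: initialization of the function-local static is
// thread-safe since C++11.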
ThreadPool &ThreadPool::GetInstance() {
  static ThreadPool instance{};
  return instance;
}

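// ClearThreadPool signals shutdown, wakes every blocked worker, and joins the
// worker threads before dropping them.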
void ThreadPool::ClearThreadPool() {
  std::lock_guard<std::mutex> sync_run_lock(pool_mtx_);
  if (exit_run_) {
    return;
  }
  exit_run_ = true;
  for (auto &context : contexts_) {
    MS_EXCEPTION_IF_NULL(context);
    context->cond_var->notify_one();
  }
  for (auto &it : sync_run_threads_) {
    MS_EXCEPTION_IF_NULL(it);
    if (it->joinable()) {
      it->join();
    }
  }
  sync_run_threads_.clear();
}

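// After fork() only the calling thread survives in the child, so the worker
// threads can be neither joined nor notified. The std::thread and condition
// variable objects are release()d (deliberately leaked) rather than
// destroyed: destroying a joinable std::thread calls std::terminate, and
// destroying a condition variable that may still have waiters is undefined
// behavior.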
void ThreadPool::ChildAfterFork() {
  THREAD_INFO("common thread pool clear thread after fork in child process");
  for (auto &context : contexts_) {
    MS_EXCEPTION_IF_NULL(context);
    if (context->cond_var != nullptr) {
      (void)context->cond_var.release();
      context->task = nullptr;
    }
  }
  contexts_.clear();
  for (auto &it : sync_run_threads_) {
    if (it != nullptr) {
      (void)it.release();
    }
  }
  sync_run_threads_.clear();
}

ThreadPool::~ThreadPool() {
  try {
    ClearThreadPool();
  } catch (...) {
    // Swallow all exceptions: a destructor must not throw.
  }
}
}  // namespace common
}  // namespace mindspore
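
// Example usage (illustrative sketch only; it assumes Task is the
// status-returning callable declared in include/common/thread_pool.h, as
// suggested by the `ret == SUCCESS` check in SyncRun above):
//
//   std::vector<mindspore::common::Task> tasks;
//   for (int shard = 0; shard < 4; ++shard) {
//     tasks.emplace_back([shard]() {
//       // ... process shard `shard` of the workload ...
//       return mindspore::common::SUCCESS;
//     });
//   }
//   // Runs the four tasks on the pool and blocks until they all finish;
//   // the empty core_list skips CPU pinning.
//   (void)mindspore::common::ThreadPool::GetInstance().SyncRun(tasks, {});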