1 // This file is part of Eigen, a lightweight C++ template library 2 // for linear algebra. 3 // 4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com> 5 // 6 // This Source Code Form is subject to the terms of the Mozilla 7 // Public License v. 2.0. If a copy of the MPL was not distributed 8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. 9 10 #ifndef EIGEN_CXX11_THREADPOOL_SIMPLE_THREAD_POOL_H 11 #define EIGEN_CXX11_THREADPOOL_SIMPLE_THREAD_POOL_H 12 13 namespace Eigen { 14 15 // The implementation of the ThreadPool type ensures that the Schedule method 16 // runs the functions it is provided in FIFO order when the scheduling is done 17 // by a single thread. 18 // Environment provides a way to create threads and also allows to intercept 19 // task submission and execution. 20 template <typename Environment> 21 class SimpleThreadPoolTempl : public ThreadPoolInterface { 22 public: 23 // Construct a pool that contains "num_threads" threads. 24 explicit SimpleThreadPoolTempl(int num_threads, Environment env = Environment()) env_(env)25 : env_(env), threads_(num_threads), waiters_(num_threads) { 26 for (int i = 0; i < num_threads; i++) { 27 threads_.push_back(env.CreateThread([this, i]() { WorkerLoop(i); })); 28 } 29 } 30 31 // Wait until all scheduled work has finished and then destroy the 32 // set of threads. ~SimpleThreadPoolTempl()33 ~SimpleThreadPoolTempl() { 34 { 35 // Wait for all work to get done. 36 std::unique_lock<std::mutex> l(mu_); 37 while (!pending_.empty()) { 38 empty_.wait(l); 39 } 40 exiting_ = true; 41 42 // Wakeup all waiters. 43 for (auto w : waiters_) { 44 w->ready = true; 45 w->task.f = nullptr; 46 w->cv.notify_one(); 47 } 48 } 49 50 // Wait for threads to finish. 51 for (auto t : threads_) { 52 delete t; 53 } 54 } 55 56 // Schedule fn() for execution in the pool of threads. The functions are 57 // executed in the order in which they are scheduled. Schedule(std::function<void ()> fn)58 void Schedule(std::function<void()> fn) final { 59 Task t = env_.CreateTask(std::move(fn)); 60 std::unique_lock<std::mutex> l(mu_); 61 if (waiters_.empty()) { 62 pending_.push_back(std::move(t)); 63 } else { 64 Waiter* w = waiters_.back(); 65 waiters_.pop_back(); 66 w->ready = true; 67 w->task = std::move(t); 68 w->cv.notify_one(); 69 } 70 } 71 NumThreads()72 int NumThreads() const final { 73 return static_cast<int>(threads_.size()); 74 } 75 CurrentThreadId()76 int CurrentThreadId() const final { 77 const PerThread* pt = this->GetPerThread(); 78 if (pt->pool == this) { 79 return pt->thread_id; 80 } else { 81 return -1; 82 } 83 } 84 85 protected: WorkerLoop(int thread_id)86 void WorkerLoop(int thread_id) { 87 std::unique_lock<std::mutex> l(mu_); 88 PerThread* pt = GetPerThread(); 89 pt->pool = this; 90 pt->thread_id = thread_id; 91 Waiter w; 92 Task t; 93 while (!exiting_) { 94 if (pending_.empty()) { 95 // Wait for work to be assigned to me 96 w.ready = false; 97 waiters_.push_back(&w); 98 while (!w.ready) { 99 w.cv.wait(l); 100 } 101 t = w.task; 102 w.task.f = nullptr; 103 } else { 104 // Pick up pending work 105 t = std::move(pending_.front()); 106 pending_.pop_front(); 107 if (pending_.empty()) { 108 empty_.notify_all(); 109 } 110 } 111 if (t.f) { 112 mu_.unlock(); 113 env_.ExecuteTask(t); 114 t.f = nullptr; 115 mu_.lock(); 116 } 117 } 118 } 119 120 private: 121 typedef typename Environment::Task Task; 122 typedef typename Environment::EnvThread Thread; 123 124 struct Waiter { 125 std::condition_variable cv; 126 Task task; 127 bool ready; 128 }; 129 130 struct PerThread { PerThreadPerThread131 constexpr PerThread() : pool(NULL), thread_id(-1) { } 132 SimpleThreadPoolTempl* pool; // Parent pool, or null for normal threads. 133 int thread_id; // Worker thread index in pool. 134 }; 135 136 Environment env_; 137 std::mutex mu_; 138 MaxSizeVector<Thread*> threads_; // All threads 139 MaxSizeVector<Waiter*> waiters_; // Stack of waiting threads. 140 std::deque<Task> pending_; // Queue of pending work 141 std::condition_variable empty_; // Signaled on pending_.empty() 142 bool exiting_ = false; 143 GetPerThread()144 PerThread* GetPerThread() const { 145 EIGEN_THREAD_LOCAL PerThread per_thread; 146 return &per_thread; 147 } 148 }; 149 150 typedef SimpleThreadPoolTempl<StlThreadEnvironment> SimpleThreadPool; 151 152 } // namespace Eigen 153 154 #endif // EIGEN_CXX11_THREADPOOL_SIMPLE_THREAD_POOL_H 155