// [Code-browser navigation residue: Home | Line# | Scopes# | Navigate# | Raw | Download]
#include <algorithm>
#include <cerrno>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <numeric>
#include <queue>
#include <thread>

#include <boost/asio/execution.hpp>
#include <boost/asio/static_thread_pool.hpp>

using boost::asio::static_thread_pool;
namespace execution = boost::asio::execution;

14 // A fixed-size thread pool used to implement fork/join semantics. Functions
15 // are scheduled using a simple FIFO queue. Implementing work stealing, or
16 // using a queue based on atomic operations, are left as tasks for the reader.
17 class fork_join_pool
18 {
19 public:
20   // The constructor starts a thread pool with the specified number of threads.
21   // Note that the thread_count is not a fixed limit on the pool's concurrency.
22   // Additional threads may temporarily be added to the pool if they join a
23   // fork_executor.
fork_join_pool(std::size_t thread_count=std::max (std::thread::hardware_concurrency (),1u)* 2)24   explicit fork_join_pool(
25       std::size_t thread_count = std::max(std::thread::hardware_concurrency(), 1u) * 2)
26     : use_count_(1),
27       threads_(thread_count)
28   {
29     try
30     {
31       // Ask each thread in the pool to dequeue and execute functions until
32       // it is time to shut down, i.e. the use count is zero.
33       for (thread_count_ = 0; thread_count_ < thread_count; ++thread_count_)
34       {
35         execution::execute(
36             threads_.executor(),
37             [this]
38             {
39               std::unique_lock<std::mutex> lock(mutex_);
40               while (use_count_ > 0)
41                 if (!execute_next(lock))
42                   condition_.wait(lock);
43             });
44       }
45     }
46     catch (...)
47     {
48       stop_threads();
49       threads_.wait();
50       throw;
51     }
52   }
53 
54   // The destructor waits for the pool to finish executing functions.
~fork_join_pool()55   ~fork_join_pool()
56   {
57     stop_threads();
58     threads_.wait();
59   }
60 
61 private:
62   friend class fork_executor;
63 
64   // The base for all functions that are queued in the pool.
65   struct function_base
66   {
67     std::shared_ptr<std::size_t> work_count_;
68     void (*execute_)(std::shared_ptr<function_base>& p);
69   };
70 
71   // Execute the next function from the queue, if any. Returns true if a
72   // function was executed, and false if the queue was empty.
execute_next(std::unique_lock<std::mutex> & lock)73   bool execute_next(std::unique_lock<std::mutex>& lock)
74   {
75     if (queue_.empty())
76       return false;
77     auto p(queue_.front());
78     queue_.pop();
79     lock.unlock();
80     execute(lock, p);
81     return true;
82   }
83 
84   // Execute a function and decrement the outstanding work.
execute(std::unique_lock<std::mutex> & lock,std::shared_ptr<function_base> & p)85   void execute(std::unique_lock<std::mutex>& lock,
86       std::shared_ptr<function_base>& p)
87   {
88     std::shared_ptr<std::size_t> work_count(std::move(p->work_count_));
89     try
90     {
91       p->execute_(p);
92       lock.lock();
93       do_work_finished(work_count);
94     }
95     catch (...)
96     {
97       lock.lock();
98       do_work_finished(work_count);
99       throw;
100     }
101   }
102 
103   // Increment outstanding work.
do_work_started(const std::shared_ptr<std::size_t> & work_count)104   void do_work_started(const std::shared_ptr<std::size_t>& work_count) noexcept
105   {
106     if (++(*work_count) == 1)
107       ++use_count_;
108   }
109 
110   // Decrement outstanding work. Notify waiting threads if we run out.
do_work_finished(const std::shared_ptr<std::size_t> & work_count)111   void do_work_finished(const std::shared_ptr<std::size_t>& work_count) noexcept
112   {
113     if (--(*work_count) == 0)
114     {
115       --use_count_;
116       condition_.notify_all();
117     }
118   }
119 
120   // Dispatch a function, executing it immediately if the queue is already
121   // loaded. Otherwise adds the function to the queue and wakes a thread.
do_execute(std::shared_ptr<function_base> p,const std::shared_ptr<std::size_t> & work_count)122   void do_execute(std::shared_ptr<function_base> p,
123       const std::shared_ptr<std::size_t>& work_count)
124   {
125     std::unique_lock<std::mutex> lock(mutex_);
126     if (queue_.size() > thread_count_ * 16)
127     {
128       do_work_started(work_count);
129       lock.unlock();
130       execute(lock, p);
131     }
132     else
133     {
134       queue_.push(p);
135       do_work_started(work_count);
136       condition_.notify_one();
137     }
138   }
139 
140   // Ask all threads to shut down.
stop_threads()141   void stop_threads()
142   {
143     std::lock_guard<std::mutex> lock(mutex_);
144     --use_count_;
145     condition_.notify_all();
146   }
147 
148   std::mutex mutex_;
149   std::condition_variable condition_;
150   std::queue<std::shared_ptr<function_base>> queue_;
151   std::size_t use_count_;
152   std::size_t thread_count_;
153   static_thread_pool threads_;
154 };
155 
156 // A class that satisfies the Executor requirements. Every function or piece of
157 // work associated with a fork_executor is part of a single, joinable group.
158 class fork_executor
159 {
160 public:
fork_executor(fork_join_pool & ctx)161   fork_executor(fork_join_pool& ctx)
162     : context_(ctx),
163       work_count_(std::make_shared<std::size_t>(0))
164   {
165   }
166 
query(execution::context_t) const167   fork_join_pool& query(execution::context_t) const noexcept
168   {
169     return context_;
170   }
171 
172   template <class Func>
execute(Func f) const173   void execute(Func f) const
174   {
175     auto p(std::make_shared<function<Func>>(std::move(f), work_count_));
176     context_.do_execute(p, work_count_);
177   }
178 
operator ==(const fork_executor & a,const fork_executor & b)179   friend bool operator==(const fork_executor& a,
180       const fork_executor& b) noexcept
181   {
182     return a.work_count_ == b.work_count_;
183   }
184 
operator !=(const fork_executor & a,const fork_executor & b)185   friend bool operator!=(const fork_executor& a,
186       const fork_executor& b) noexcept
187   {
188     return a.work_count_ != b.work_count_;
189   }
190 
191   // Block until all work associated with the executor is complete. While it is
192   // waiting, the thread may be borrowed to execute functions from the queue.
join() const193   void join() const
194   {
195     std::unique_lock<std::mutex> lock(context_.mutex_);
196     while (*work_count_ > 0)
197       if (!context_.execute_next(lock))
198         context_.condition_.wait(lock);
199   }
200 
201 private:
202   template <class Func>
203   struct function : fork_join_pool::function_base
204   {
functionfork_executor::function205     explicit function(Func f, const std::shared_ptr<std::size_t>& w)
206       : function_(std::move(f))
207     {
208       work_count_ = w;
209       execute_ = [](std::shared_ptr<fork_join_pool::function_base>& p)
210       {
211         Func tmp(std::move(static_cast<function*>(p.get())->function_));
212         p.reset();
213         tmp();
214       };
215     }
216 
217     Func function_;
218   };
219 
220   fork_join_pool& context_;
221   std::shared_ptr<std::size_t> work_count_;
222 };
223 
224 // Helper class to automatically join a fork_executor when exiting a scope.
225 class join_guard
226 {
227 public:
join_guard(const fork_executor & ex)228   explicit join_guard(const fork_executor& ex) : ex_(ex) {}
229   join_guard(const join_guard&) = delete;
230   join_guard(join_guard&&) = delete;
~join_guard()231   ~join_guard() { ex_.join(); }
232 
233 private:
234   fork_executor ex_;
235 };
236 
237 //------------------------------------------------------------------------------
238 
239 #include <algorithm>
240 #include <iostream>
241 #include <random>
242 #include <vector>
243 
244 fork_join_pool pool;
245 
246 template <class Iterator>
fork_join_sort(Iterator begin,Iterator end)247 void fork_join_sort(Iterator begin, Iterator end)
248 {
249   std::size_t n = end - begin;
250   if (n > 32768)
251   {
252     {
253       fork_executor fork(pool);
254       join_guard join(fork);
255       execution::execute(fork, [=]{ fork_join_sort(begin, begin + n / 2); });
256       execution::execute(fork, [=]{ fork_join_sort(begin + n / 2, end); });
257     }
258     std::inplace_merge(begin, begin + n / 2, end);
259   }
260   else
261   {
262     std::sort(begin, end);
263   }
264 }
265 
main(int argc,char * argv[])266 int main(int argc, char* argv[])
267 {
268   if (argc != 2)
269   {
270     std::cerr << "Usage: fork_join <size>\n";
271     return 1;
272   }
273 
274   std::vector<double> vec(std::atoll(argv[1]));
275   std::iota(vec.begin(), vec.end(), 0);
276 
277   std::random_device rd;
278   std::mt19937 g(rd());
279   std::shuffle(vec.begin(), vec.end(), g);
280 
281   std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
282 
283   fork_join_sort(vec.begin(), vec.end());
284 
285   std::chrono::steady_clock::duration elapsed = std::chrono::steady_clock::now() - start;
286 
287   std::cout << "sort took ";
288   std::cout << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
289   std::cout << " microseconds" << std::endl;
290 }
291