#include <boost/asio/execution.hpp>
#include <boost/asio/static_thread_pool.hpp>
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <mutex>
#include <numeric>
#include <queue>
#include <thread>
#include <utility>
10
11 using boost::asio::static_thread_pool;
12 namespace execution = boost::asio::execution;
13
14 // A fixed-size thread pool used to implement fork/join semantics. Functions
15 // are scheduled using a simple FIFO queue. Implementing work stealing, or
16 // using a queue based on atomic operations, are left as tasks for the reader.
17 class fork_join_pool
18 {
19 public:
20 // The constructor starts a thread pool with the specified number of threads.
21 // Note that the thread_count is not a fixed limit on the pool's concurrency.
22 // Additional threads may temporarily be added to the pool if they join a
23 // fork_executor.
fork_join_pool(std::size_t thread_count=std::max (std::thread::hardware_concurrency (),1u)* 2)24 explicit fork_join_pool(
25 std::size_t thread_count = std::max(std::thread::hardware_concurrency(), 1u) * 2)
26 : use_count_(1),
27 threads_(thread_count)
28 {
29 try
30 {
31 // Ask each thread in the pool to dequeue and execute functions until
32 // it is time to shut down, i.e. the use count is zero.
33 for (thread_count_ = 0; thread_count_ < thread_count; ++thread_count_)
34 {
35 execution::execute(
36 threads_.executor(),
37 [this]
38 {
39 std::unique_lock<std::mutex> lock(mutex_);
40 while (use_count_ > 0)
41 if (!execute_next(lock))
42 condition_.wait(lock);
43 });
44 }
45 }
46 catch (...)
47 {
48 stop_threads();
49 threads_.wait();
50 throw;
51 }
52 }
53
54 // The destructor waits for the pool to finish executing functions.
~fork_join_pool()55 ~fork_join_pool()
56 {
57 stop_threads();
58 threads_.wait();
59 }
60
61 private:
62 friend class fork_executor;
63
64 // The base for all functions that are queued in the pool.
65 struct function_base
66 {
67 std::shared_ptr<std::size_t> work_count_;
68 void (*execute_)(std::shared_ptr<function_base>& p);
69 };
70
71 // Execute the next function from the queue, if any. Returns true if a
72 // function was executed, and false if the queue was empty.
execute_next(std::unique_lock<std::mutex> & lock)73 bool execute_next(std::unique_lock<std::mutex>& lock)
74 {
75 if (queue_.empty())
76 return false;
77 auto p(queue_.front());
78 queue_.pop();
79 lock.unlock();
80 execute(lock, p);
81 return true;
82 }
83
84 // Execute a function and decrement the outstanding work.
execute(std::unique_lock<std::mutex> & lock,std::shared_ptr<function_base> & p)85 void execute(std::unique_lock<std::mutex>& lock,
86 std::shared_ptr<function_base>& p)
87 {
88 std::shared_ptr<std::size_t> work_count(std::move(p->work_count_));
89 try
90 {
91 p->execute_(p);
92 lock.lock();
93 do_work_finished(work_count);
94 }
95 catch (...)
96 {
97 lock.lock();
98 do_work_finished(work_count);
99 throw;
100 }
101 }
102
103 // Increment outstanding work.
do_work_started(const std::shared_ptr<std::size_t> & work_count)104 void do_work_started(const std::shared_ptr<std::size_t>& work_count) noexcept
105 {
106 if (++(*work_count) == 1)
107 ++use_count_;
108 }
109
110 // Decrement outstanding work. Notify waiting threads if we run out.
do_work_finished(const std::shared_ptr<std::size_t> & work_count)111 void do_work_finished(const std::shared_ptr<std::size_t>& work_count) noexcept
112 {
113 if (--(*work_count) == 0)
114 {
115 --use_count_;
116 condition_.notify_all();
117 }
118 }
119
120 // Dispatch a function, executing it immediately if the queue is already
121 // loaded. Otherwise adds the function to the queue and wakes a thread.
do_execute(std::shared_ptr<function_base> p,const std::shared_ptr<std::size_t> & work_count)122 void do_execute(std::shared_ptr<function_base> p,
123 const std::shared_ptr<std::size_t>& work_count)
124 {
125 std::unique_lock<std::mutex> lock(mutex_);
126 if (queue_.size() > thread_count_ * 16)
127 {
128 do_work_started(work_count);
129 lock.unlock();
130 execute(lock, p);
131 }
132 else
133 {
134 queue_.push(p);
135 do_work_started(work_count);
136 condition_.notify_one();
137 }
138 }
139
140 // Ask all threads to shut down.
stop_threads()141 void stop_threads()
142 {
143 std::lock_guard<std::mutex> lock(mutex_);
144 --use_count_;
145 condition_.notify_all();
146 }
147
148 std::mutex mutex_;
149 std::condition_variable condition_;
150 std::queue<std::shared_ptr<function_base>> queue_;
151 std::size_t use_count_;
152 std::size_t thread_count_;
153 static_thread_pool threads_;
154 };
155
156 // A class that satisfies the Executor requirements. Every function or piece of
157 // work associated with a fork_executor is part of a single, joinable group.
158 class fork_executor
159 {
160 public:
fork_executor(fork_join_pool & ctx)161 fork_executor(fork_join_pool& ctx)
162 : context_(ctx),
163 work_count_(std::make_shared<std::size_t>(0))
164 {
165 }
166
query(execution::context_t) const167 fork_join_pool& query(execution::context_t) const noexcept
168 {
169 return context_;
170 }
171
172 template <class Func>
execute(Func f) const173 void execute(Func f) const
174 {
175 auto p(std::make_shared<function<Func>>(std::move(f), work_count_));
176 context_.do_execute(p, work_count_);
177 }
178
operator ==(const fork_executor & a,const fork_executor & b)179 friend bool operator==(const fork_executor& a,
180 const fork_executor& b) noexcept
181 {
182 return a.work_count_ == b.work_count_;
183 }
184
operator !=(const fork_executor & a,const fork_executor & b)185 friend bool operator!=(const fork_executor& a,
186 const fork_executor& b) noexcept
187 {
188 return a.work_count_ != b.work_count_;
189 }
190
191 // Block until all work associated with the executor is complete. While it is
192 // waiting, the thread may be borrowed to execute functions from the queue.
join() const193 void join() const
194 {
195 std::unique_lock<std::mutex> lock(context_.mutex_);
196 while (*work_count_ > 0)
197 if (!context_.execute_next(lock))
198 context_.condition_.wait(lock);
199 }
200
201 private:
202 template <class Func>
203 struct function : fork_join_pool::function_base
204 {
functionfork_executor::function205 explicit function(Func f, const std::shared_ptr<std::size_t>& w)
206 : function_(std::move(f))
207 {
208 work_count_ = w;
209 execute_ = [](std::shared_ptr<fork_join_pool::function_base>& p)
210 {
211 Func tmp(std::move(static_cast<function*>(p.get())->function_));
212 p.reset();
213 tmp();
214 };
215 }
216
217 Func function_;
218 };
219
220 fork_join_pool& context_;
221 std::shared_ptr<std::size_t> work_count_;
222 };
223
224 // Helper class to automatically join a fork_executor when exiting a scope.
225 class join_guard
226 {
227 public:
join_guard(const fork_executor & ex)228 explicit join_guard(const fork_executor& ex) : ex_(ex) {}
229 join_guard(const join_guard&) = delete;
230 join_guard(join_guard&&) = delete;
~join_guard()231 ~join_guard() { ex_.join(); }
232
233 private:
234 fork_executor ex_;
235 };
236
237 //------------------------------------------------------------------------------
238
239 #include <algorithm>
240 #include <iostream>
241 #include <random>
242 #include <vector>
243
244 fork_join_pool pool;
245
246 template <class Iterator>
fork_join_sort(Iterator begin,Iterator end)247 void fork_join_sort(Iterator begin, Iterator end)
248 {
249 std::size_t n = end - begin;
250 if (n > 32768)
251 {
252 {
253 fork_executor fork(pool);
254 join_guard join(fork);
255 execution::execute(fork, [=]{ fork_join_sort(begin, begin + n / 2); });
256 execution::execute(fork, [=]{ fork_join_sort(begin + n / 2, end); });
257 }
258 std::inplace_merge(begin, begin + n / 2, end);
259 }
260 else
261 {
262 std::sort(begin, end);
263 }
264 }
265
main(int argc,char * argv[])266 int main(int argc, char* argv[])
267 {
268 if (argc != 2)
269 {
270 std::cerr << "Usage: fork_join <size>\n";
271 return 1;
272 }
273
274 std::vector<double> vec(std::atoll(argv[1]));
275 std::iota(vec.begin(), vec.end(), 0);
276
277 std::random_device rd;
278 std::mt19937 g(rd());
279 std::shuffle(vec.begin(), vec.end(), g);
280
281 std::chrono::steady_clock::time_point start = std::chrono::steady_clock::now();
282
283 fork_join_sort(vec.begin(), vec.end());
284
285 std::chrono::steady_clock::duration elapsed = std::chrono::steady_clock::now() - start;
286
287 std::cout << "sort took ";
288 std::cout << std::chrono::duration_cast<std::chrono::microseconds>(elapsed).count();
289 std::cout << " microseconds" << std::endl;
290 }
291