• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use crate::enter;
2 use crate::unpark_mutex::UnparkMutex;
3 use futures_core::future::Future;
4 use futures_core::task::{Context, Poll};
5 use futures_task::{waker_ref, ArcWake};
6 use futures_task::{FutureObj, Spawn, SpawnError};
7 use futures_util::future::FutureExt;
8 use std::cmp;
9 use std::fmt;
10 use std::io;
11 use std::sync::atomic::{AtomicUsize, Ordering};
12 use std::sync::mpsc;
13 use std::sync::{Arc, Mutex};
14 use std::thread;
15 
16 /// A general-purpose thread pool for scheduling tasks that poll futures to
17 /// completion.
18 ///
19 /// The thread pool multiplexes any number of tasks onto a fixed number of
20 /// worker threads.
21 ///
22 /// This type is a clonable handle to the threadpool itself.
23 /// Cloning it will only create a new reference, not a new threadpool.
24 ///
25 /// This type is only available when the `thread-pool` feature of this
26 /// library is activated.
27 #[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
28 pub struct ThreadPool {
29     state: Arc<PoolState>,
30 }
31 
/// Thread pool configuration object.
///
/// Adjust the settings with the setter methods, then call `create` to spawn
/// a pool with that configuration.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    /// Number of worker threads to spawn.
    pool_size: usize,
    /// Worker stack size in bytes; `0` keeps Rust's default stack size.
    stack_size: usize,
    /// Optional prefix; each worker is named `{prefix}{index}`.
    name_prefix: Option<String>,
    /// Hook called with the worker index when each worker starts.
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    /// Hook called with the worker index just before each worker exits.
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}
44 
45 trait AssertSendSync: Send + Sync {}
46 impl AssertSendSync for ThreadPool {}
47 
48 struct PoolState {
49     tx: Mutex<mpsc::Sender<Message>>,
50     rx: Mutex<mpsc::Receiver<Message>>,
51     cnt: AtomicUsize,
52     size: usize,
53 }
54 
55 impl fmt::Debug for ThreadPool {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result56     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57         f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
58     }
59 }
60 
61 impl fmt::Debug for ThreadPoolBuilder {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result62     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63         f.debug_struct("ThreadPoolBuilder")
64             .field("pool_size", &self.pool_size)
65             .field("name_prefix", &self.name_prefix)
66             .finish()
67     }
68 }
69 
70 enum Message {
71     Run(Task),
72     Close,
73 }
74 
75 impl ThreadPool {
76     /// Creates a new thread pool with the default configuration.
77     ///
78     /// See documentation for the methods in
79     /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
80     /// configuration.
new() -> Result<Self, io::Error>81     pub fn new() -> Result<Self, io::Error> {
82         ThreadPoolBuilder::new().create()
83     }
84 
85     /// Create a default thread pool configuration, which can then be customized.
86     ///
87     /// See documentation for the methods in
88     /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
89     /// configuration.
builder() -> ThreadPoolBuilder90     pub fn builder() -> ThreadPoolBuilder {
91         ThreadPoolBuilder::new()
92     }
93 
94     /// Spawns a future that will be run to completion.
95     ///
96     /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that
97     /// >           it is guaranteed to always succeed.
spawn_obj_ok(&self, future: FutureObj<'static, ()>)98     pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
99         let task = Task {
100             future,
101             wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
102             exec: self.clone(),
103         };
104         self.state.send(Message::Run(task));
105     }
106 
107     /// Spawns a task that polls the given future with output `()` to
108     /// completion.
109     ///
110     /// ```
111     /// # {
112     /// use futures::executor::ThreadPool;
113     ///
114     /// let pool = ThreadPool::new().unwrap();
115     ///
116     /// let future = async { /* ... */ };
117     /// pool.spawn_ok(future);
118     /// # }
119     /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
120     /// ```
121     ///
122     /// > **Note**: This method is similar to `SpawnExt::spawn`, except that
123     /// >           it is guaranteed to always succeed.
spawn_ok<Fut>(&self, future: Fut) where Fut: Future<Output = ()> + Send + 'static,124     pub fn spawn_ok<Fut>(&self, future: Fut)
125     where
126         Fut: Future<Output = ()> + Send + 'static,
127     {
128         self.spawn_obj_ok(FutureObj::new(Box::new(future)))
129     }
130 }
131 
132 impl Spawn for ThreadPool {
spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError>133     fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
134         self.spawn_obj_ok(future);
135         Ok(())
136     }
137 }
138 
139 impl PoolState {
send(&self, msg: Message)140     fn send(&self, msg: Message) {
141         self.tx.lock().unwrap().send(msg).unwrap();
142     }
143 
work( &self, idx: usize, after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>, before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>, )144     fn work(
145         &self,
146         idx: usize,
147         after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
148         before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
149     ) {
150         let _scope = enter().unwrap();
151         if let Some(after_start) = after_start {
152             after_start(idx);
153         }
154         loop {
155             let msg = self.rx.lock().unwrap().recv().unwrap();
156             match msg {
157                 Message::Run(task) => task.run(),
158                 Message::Close => break,
159             }
160         }
161         if let Some(before_stop) = before_stop {
162             before_stop(idx);
163         }
164     }
165 }
166 
167 impl Clone for ThreadPool {
clone(&self) -> Self168     fn clone(&self) -> Self {
169         self.state.cnt.fetch_add(1, Ordering::Relaxed);
170         Self { state: self.state.clone() }
171     }
172 }
173 
174 impl Drop for ThreadPool {
drop(&mut self)175     fn drop(&mut self) {
176         if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
177             for _ in 0..self.state.size {
178                 self.state.send(Message::Close);
179             }
180         }
181     }
182 }
183 
184 impl ThreadPoolBuilder {
185     /// Create a default thread pool configuration.
186     ///
187     /// See the other methods on this type for details on the defaults.
new() -> Self188     pub fn new() -> Self {
189         Self {
190             pool_size: cmp::max(1, num_cpus::get()),
191             stack_size: 0,
192             name_prefix: None,
193             after_start: None,
194             before_stop: None,
195         }
196     }
197 
198     /// Set size of a future ThreadPool
199     ///
200     /// The size of a thread pool is the number of worker threads spawned. By
201     /// default, this is equal to the number of CPU cores.
202     ///
203     /// # Panics
204     ///
205     /// Panics if `pool_size == 0`.
pool_size(&mut self, size: usize) -> &mut Self206     pub fn pool_size(&mut self, size: usize) -> &mut Self {
207         assert!(size > 0);
208         self.pool_size = size;
209         self
210     }
211 
212     /// Set stack size of threads in the pool, in bytes.
213     ///
214     /// By default, worker threads use Rust's standard stack size.
stack_size(&mut self, stack_size: usize) -> &mut Self215     pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
216         self.stack_size = stack_size;
217         self
218     }
219 
220     /// Set thread name prefix of a future ThreadPool.
221     ///
222     /// Thread name prefix is used for generating thread names. For example, if prefix is
223     /// `my-pool-`, then threads in the pool will get names like `my-pool-1` etc.
224     ///
225     /// By default, worker threads are assigned Rust's standard thread name.
name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self226     pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
227         self.name_prefix = Some(name_prefix.into());
228         self
229     }
230 
231     /// Execute the closure `f` immediately after each worker thread is started,
232     /// but before running any tasks on it.
233     ///
234     /// This hook is intended for bookkeeping and monitoring.
235     /// The closure `f` will be dropped after the `builder` is dropped
236     /// and all worker threads in the pool have executed it.
237     ///
238     /// The closure provided will receive an index corresponding to the worker
239     /// thread it's running on.
after_start<F>(&mut self, f: F) -> &mut Self where F: Fn(usize) + Send + Sync + 'static,240     pub fn after_start<F>(&mut self, f: F) -> &mut Self
241     where
242         F: Fn(usize) + Send + Sync + 'static,
243     {
244         self.after_start = Some(Arc::new(f));
245         self
246     }
247 
248     /// Execute closure `f` just prior to shutting down each worker thread.
249     ///
250     /// This hook is intended for bookkeeping and monitoring.
251     /// The closure `f` will be dropped after the `builder` is dropped
252     /// and all threads in the pool have executed it.
253     ///
254     /// The closure provided will receive an index corresponding to the worker
255     /// thread it's running on.
before_stop<F>(&mut self, f: F) -> &mut Self where F: Fn(usize) + Send + Sync + 'static,256     pub fn before_stop<F>(&mut self, f: F) -> &mut Self
257     where
258         F: Fn(usize) + Send + Sync + 'static,
259     {
260         self.before_stop = Some(Arc::new(f));
261         self
262     }
263 
264     /// Create a [`ThreadPool`](ThreadPool) with the given configuration.
create(&mut self) -> Result<ThreadPool, io::Error>265     pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
266         let (tx, rx) = mpsc::channel();
267         let pool = ThreadPool {
268             state: Arc::new(PoolState {
269                 tx: Mutex::new(tx),
270                 rx: Mutex::new(rx),
271                 cnt: AtomicUsize::new(1),
272                 size: self.pool_size,
273             }),
274         };
275 
276         for counter in 0..self.pool_size {
277             let state = pool.state.clone();
278             let after_start = self.after_start.clone();
279             let before_stop = self.before_stop.clone();
280             let mut thread_builder = thread::Builder::new();
281             if let Some(ref name_prefix) = self.name_prefix {
282                 thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
283             }
284             if self.stack_size > 0 {
285                 thread_builder = thread_builder.stack_size(self.stack_size);
286             }
287             thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
288         }
289         Ok(pool)
290     }
291 }
292 
293 impl Default for ThreadPoolBuilder {
default() -> Self294     fn default() -> Self {
295         Self::new()
296     }
297 }
298 
299 /// A task responsible for polling a future to completion.
300 struct Task {
301     future: FutureObj<'static, ()>,
302     exec: ThreadPool,
303     wake_handle: Arc<WakeHandle>,
304 }
305 
306 struct WakeHandle {
307     mutex: UnparkMutex<Task>,
308     exec: ThreadPool,
309 }
310 
impl Task {
    /// Actually run the task (invoking `poll` on the future) on the current
    /// thread.
    ///
    /// The future is polled repeatedly with a `Waker` backed by this task's
    /// `wake_handle`:
    /// - On `Poll::Ready`, the wake mutex is marked complete and the task ends.
    /// - On `Poll::Pending`, the task is handed to `wake_handle.mutex.wait`.
    ///   If `wait` accepts it (`Ok(())`), this call returns; the task is later
    ///   resubmitted by `WakeHandle::wake_by_ref`, which sends the task
    ///   returned from `notify()` back to the pool.
    ///   If `wait` returns the task (`Err(task)`), the task was notified
    ///   during the poll, so it is polled again right away on this thread.
    fn run(self) {
        let Self { mut future, wake_handle, mut exec } = self;
        // The waker borrows `wake_handle`, so wake-ups reach `WakeHandle::wake_by_ref`.
        let waker = waker_ref(&wake_handle);
        let mut cx = Context::from_waker(&waker);

        // Safety: The ownership of this `Task` object is evidence that
        // we are in the `POLLING`/`REPOLL` state for the mutex.
        // NOTE(review): the exact state contract of `start_poll`/`wait`/
        // `complete` is defined by `UnparkMutex`, which is outside this view.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = future.poll_unpin(&mut cx);
                match res {
                    Poll::Pending => {}
                    Poll::Ready(()) => return wake_handle.mutex.complete(),
                }
                // Re-package the pieces so the mutex can hold the task while it waits.
                let task = Self { future, wake_handle: wake_handle.clone(), exec };
                match wake_handle.mutex.wait(task) {
                    Ok(()) => return, // we've waited
                    Err(task) => {
                        // someone's notified us
                        // Take the pieces back and loop to poll again.
                        future = task.future;
                        exec = task.exec;
                    }
                }
            }
        }
    }
}
343 
344 impl fmt::Debug for Task {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result345     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
346         f.debug_struct("Task").field("contents", &"...").finish()
347     }
348 }
349 
350 impl ArcWake for WakeHandle {
wake_by_ref(arc_self: &Arc<Self>)351     fn wake_by_ref(arc_self: &Arc<Self>) {
352         if let Ok(task) = arc_self.mutex.notify() {
353             arc_self.exec.state.send(Message::Run(task))
354         }
355     }
356 }
357 
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    #[test]
    fn test_drop_after_start() {
        {
            // Each worker reports once from its `after_start` hook.
            let (started_tx, started_rx) = mpsc::sync_channel(2);
            let _pool = ThreadPoolBuilder::new()
                .pool_size(2)
                .after_start(move |_| started_tx.send(1).unwrap())
                .create()
                .unwrap();

            // The temporary builder is dropped at the end of the statement
            // above, and each worker drops its hook clone after calling it.
            // The last sender therefore goes away and iterating the receiver
            // terminates.
            let reports = started_rx.iter().count();
            assert_eq!(reports, 2);
        }
        // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
        std::thread::sleep(std::time::Duration::from_millis(500));
    }
}
381