use crate::enter;
use crate::unpark_mutex::UnparkMutex;
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use futures_task::{waker_ref, ArcWake};
use futures_task::{FutureObj, Spawn, SpawnError};
use futures_util::future::FutureExt;
use std::cmp;
use std::fmt;
use std::io;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

/// A general-purpose thread pool for scheduling tasks that poll futures to
/// completion.
///
/// The thread pool multiplexes any number of tasks onto a fixed number of
/// worker threads.
///
/// This type is a clonable handle to the threadpool itself.
/// Cloning it will only create a new reference, not a new threadpool.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
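///
/// A minimal sketch of sharing one pool through clones of the handle (the
/// task bodies are placeholders):
///
/// ```
/// use futures::executor::ThreadPool;
///
/// let pool = ThreadPool::new().unwrap();
/// // Cloning the handle shares the same worker threads.
/// let handle = pool.clone();
/// handle.spawn_ok(async { /* ... */ });
/// pool.spawn_ok(async { /* ... */ });
/// ```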
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPool {
    state: Arc<PoolState>,
}

/// Thread pool configuration object.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    pool_size: usize,
    stack_size: usize,
    name_prefix: Option<String>,
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}

trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}

struct PoolState {
    tx: Mutex<mpsc::Sender<Message>>,
    rx: Mutex<mpsc::Receiver<Message>>,
    cnt: AtomicUsize,
    size: usize,
}

impl fmt::Debug for ThreadPool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
    }
}

impl fmt::Debug for ThreadPoolBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadPoolBuilder")
            .field("pool_size", &self.pool_size)
            .field("name_prefix", &self.name_prefix)
            .finish()
    }
}

enum Message {
    Run(Task),
    Close,
}

impl ThreadPool {
    /// Creates a new thread pool with the default configuration.
    ///
    /// See documentation for the methods in
    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
    /// configuration.
    pub fn new() -> Result<Self, io::Error> {
        ThreadPoolBuilder::new().create()
    }

    /// Create a default thread pool configuration, which can then be customized.
    ///
    /// See documentation for the methods in
    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
    /// configuration.
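    ///
    /// A minimal sketch of customizing the configuration (the pool size and
    /// name prefix shown are arbitrary):
    ///
    /// ```
    /// use futures::executor::ThreadPool;
    ///
    /// let pool = ThreadPool::builder()
    ///     .pool_size(4)
    ///     .name_prefix("pool-")
    ///     .create()
    ///     .unwrap();
    /// pool.spawn_ok(async { /* ... */ });
    /// ```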
    pub fn builder() -> ThreadPoolBuilder {
        ThreadPoolBuilder::new()
    }

    /// Spawns a future that will be run to completion.
    ///
    /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that
    /// >           it is guaranteed to always succeed.
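    ///
    /// A minimal sketch of spawning a boxed future (the future body is a
    /// placeholder):
    ///
    /// ```
    /// use futures::executor::ThreadPool;
    /// use futures::task::FutureObj;
    ///
    /// let pool = ThreadPool::new().unwrap();
    /// let future = FutureObj::new(Box::new(async { /* ... */ }));
    /// pool.spawn_obj_ok(future);
    /// ```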
    pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
        let task = Task {
            future,
            wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
            exec: self.clone(),
        };
        self.state.send(Message::Run(task));
    }

    /// Spawns a task that polls the given future with output `()` to
    /// completion.
    ///
    /// ```
    /// use futures::executor::ThreadPool;
    ///
    /// let pool = ThreadPool::new().unwrap();
    ///
    /// let future = async { /* ... */ };
    /// pool.spawn_ok(future);
    /// ```
    ///
    /// > **Note**: This method is similar to `SpawnExt::spawn`, except that
    /// >           it is guaranteed to always succeed.
    pub fn spawn_ok<Fut>(&self, future: Fut)
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.spawn_obj_ok(FutureObj::new(Box::new(future)))
    }
}

impl Spawn for ThreadPool {
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
        self.spawn_obj_ok(future);
        Ok(())
    }
}

impl PoolState {
    fn send(&self, msg: Message) {
        self.tx.lock().unwrap().send(msg).unwrap();
    }

    fn work(
        &self,
        idx: usize,
        after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
        before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    ) {
        let _scope = enter().unwrap();
        if let Some(after_start) = after_start {
            after_start(idx);
        }
        loop {
            let msg = self.rx.lock().unwrap().recv().unwrap();
            match msg {
                Message::Run(task) => task.run(),
                Message::Close => break,
            }
        }
        if let Some(before_stop) = before_stop {
            before_stop(idx);
        }
    }
}

impl Clone for ThreadPool {
    fn clone(&self) -> Self {
        self.state.cnt.fetch_add(1, Ordering::Relaxed);
        Self { state: self.state.clone() }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
            for _ in 0..self.state.size {
                self.state.send(Message::Close);
            }
        }
    }
}

impl ThreadPoolBuilder {
    /// Create a default thread pool configuration.
    ///
    /// See the other methods on this type for details on the defaults.
    pub fn new() -> Self {
        Self {
            pool_size: cmp::max(1, num_cpus::get()),
            stack_size: 0,
            name_prefix: None,
            after_start: None,
            before_stop: None,
        }
    }

    /// Set the size of the future `ThreadPool`.
    ///
    /// The size of a thread pool is the number of worker threads spawned. By
    /// default, this is equal to the number of CPU cores.
    ///
    /// # Panics
    ///
    /// Panics if `pool_size == 0`.
    pub fn pool_size(&mut self, size: usize) -> &mut Self {
        assert!(size > 0);
        self.pool_size = size;
        self
    }

    /// Set the stack size of threads in the pool, in bytes.
    ///
    /// By default, worker threads use Rust's standard stack size.
    pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
        self.stack_size = stack_size;
        self
    }

    /// Set the thread name prefix for the future `ThreadPool`.
    ///
    /// The thread name prefix is used for generating thread names. For example,
    /// if the prefix is `my-pool-`, then threads in the pool will get names like
    /// `my-pool-1` etc.
    ///
    /// By default, worker threads are assigned Rust's standard thread name.
    pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
        self.name_prefix = Some(name_prefix.into());
        self
    }

    /// Execute the closure `f` immediately after each worker thread is started,
    /// but before running any tasks on it.
    ///
    /// This hook is intended for bookkeeping and monitoring.
    /// The closure `f` will be dropped after the `builder` is dropped
    /// and all worker threads in the pool have executed it.
    ///
    /// The closure provided will receive an index corresponding to the worker
    /// thread it's running on.
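    ///
    /// A minimal sketch of using the hook for bookkeeping (the counter and
    /// pool size are arbitrary):
    ///
    /// ```
    /// use std::sync::atomic::{AtomicUsize, Ordering};
    /// use std::sync::Arc;
    /// use futures::executor::ThreadPoolBuilder;
    ///
    /// let started = Arc::new(AtomicUsize::new(0));
    /// let counter = started.clone();
    /// let pool = ThreadPoolBuilder::new()
    ///     .pool_size(2)
    ///     // Each worker bumps the shared counter once it has started.
    ///     .after_start(move |_idx| { counter.fetch_add(1, Ordering::SeqCst); })
    ///     .create()
    ///     .unwrap();
    /// ```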
    pub fn after_start<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.after_start = Some(Arc::new(f));
        self
    }

    /// Execute closure `f` just prior to shutting down each worker thread.
    ///
    /// This hook is intended for bookkeeping and monitoring.
    /// The closure `f` will be dropped after the `builder` is dropped
    /// and all threads in the pool have executed it.
    ///
    /// The closure provided will receive an index corresponding to the worker
    /// thread it's running on.
    pub fn before_stop<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.before_stop = Some(Arc::new(f));
        self
    }

    /// Create a [`ThreadPool`](ThreadPool) with the given configuration.
    pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
        let (tx, rx) = mpsc::channel();
        let pool = ThreadPool {
            state: Arc::new(PoolState {
                tx: Mutex::new(tx),
                rx: Mutex::new(rx),
                cnt: AtomicUsize::new(1),
                size: self.pool_size,
            }),
        };

        for counter in 0..self.pool_size {
            let state = pool.state.clone();
            let after_start = self.after_start.clone();
            let before_stop = self.before_stop.clone();
            let mut thread_builder = thread::Builder::new();
            if let Some(ref name_prefix) = self.name_prefix {
                thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
            }
            if self.stack_size > 0 {
                thread_builder = thread_builder.stack_size(self.stack_size);
            }
            thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
        }
        Ok(pool)
    }
}

impl Default for ThreadPoolBuilder {
    fn default() -> Self {
        Self::new()
    }
}

/// A task responsible for polling a future to completion.
struct Task {
    future: FutureObj<'static, ()>,
    exec: ThreadPool,
    wake_handle: Arc<WakeHandle>,
}

struct WakeHandle {
    mutex: UnparkMutex<Task>,
    exec: ThreadPool,
}

impl Task {
    /// Actually run the task (invoking `poll` on the future) on the current
    /// thread.
    fn run(self) {
        let Self { mut future, wake_handle, mut exec } = self;
        let waker = waker_ref(&wake_handle);
        let mut cx = Context::from_waker(&waker);

        // Safety: The ownership of this `Task` object is evidence that
        // we are in the `POLLING`/`REPOLL` state for the mutex.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = future.poll_unpin(&mut cx);
                match res {
                    Poll::Pending => {}
                    Poll::Ready(()) => return wake_handle.mutex.complete(),
                }
                let task = Self { future, wake_handle: wake_handle.clone(), exec };
                match wake_handle.mutex.wait(task) {
                    Ok(()) => return, // we've waited
                    Err(task) => {
                        // someone's notified us
                        future = task.future;
                        exec = task.exec;
                    }
                }
            }
        }
    }
}

impl fmt::Debug for Task {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Task").field("contents", &"...").finish()
    }
}

impl ArcWake for WakeHandle {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        match arc_self.mutex.notify() {
            Ok(task) => arc_self.exec.state.send(Message::Run(task)),
            Err(()) => {}
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    #[test]
    fn test_drop_after_start() {
        let (tx, rx) = mpsc::sync_channel(2);
        let _cpu_pool = ThreadPoolBuilder::new()
            .pool_size(2)
            .after_start(move |_| tx.send(1).unwrap())
            .create()
            .unwrap();

        // After the `ThreadPoolBuilder` has been dropped, `tx` is dropped as
        // well, so we can use `rx` as a finite iterator.
        let count = rx.into_iter().count();
        assert_eq!(count, 2);
    }
}