1 use crate::enter; 2 use crate::unpark_mutex::UnparkMutex; 3 use futures_core::future::Future; 4 use futures_core::task::{Context, Poll}; 5 use futures_task::{waker_ref, ArcWake}; 6 use futures_task::{FutureObj, Spawn, SpawnError}; 7 use futures_util::future::FutureExt; 8 use std::cmp; 9 use std::fmt; 10 use std::io; 11 use std::sync::atomic::{AtomicUsize, Ordering}; 12 use std::sync::mpsc; 13 use std::sync::{Arc, Mutex}; 14 use std::thread; 15 16 /// A general-purpose thread pool for scheduling tasks that poll futures to 17 /// completion. 18 /// 19 /// The thread pool multiplexes any number of tasks onto a fixed number of 20 /// worker threads. 21 /// 22 /// This type is a clonable handle to the threadpool itself. 23 /// Cloning it will only create a new reference, not a new threadpool. 24 /// 25 /// This type is only available when the `thread-pool` feature of this 26 /// library is activated. 27 #[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))] 28 pub struct ThreadPool { 29 state: Arc<PoolState>, 30 } 31 32 /// Thread pool configuration object. 33 /// 34 /// This type is only available when the `thread-pool` feature of this 35 /// library is activated. 36 #[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))] 37 pub struct ThreadPoolBuilder { 38 pool_size: usize, 39 stack_size: usize, 40 name_prefix: Option<String>, 41 after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>, 42 before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>, 43 } 44 45 trait AssertSendSync: Send + Sync {} 46 impl AssertSendSync for ThreadPool {} 47 48 struct PoolState { 49 tx: Mutex<mpsc::Sender<Message>>, 50 rx: Mutex<mpsc::Receiver<Message>>, 51 cnt: AtomicUsize, 52 size: usize, 53 } 54 55 impl fmt::Debug for ThreadPool { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 57 f.debug_struct("ThreadPool").field("size", &self.state.size).finish() 58 } 59 } 60 61 impl fmt::Debug for ThreadPoolBuilder { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 63 f.debug_struct("ThreadPoolBuilder") 64 .field("pool_size", &self.pool_size) 65 .field("name_prefix", &self.name_prefix) 66 .finish() 67 } 68 } 69 70 enum Message { 71 Run(Task), 72 Close, 73 } 74 75 impl ThreadPool { 76 /// Creates a new thread pool with the default configuration. 77 /// 78 /// See documentation for the methods in 79 /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default 80 /// configuration. new() -> Result<Self, io::Error>81 pub fn new() -> Result<Self, io::Error> { 82 ThreadPoolBuilder::new().create() 83 } 84 85 /// Create a default thread pool configuration, which can then be customized. 86 /// 87 /// See documentation for the methods in 88 /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default 89 /// configuration. builder() -> ThreadPoolBuilder90 pub fn builder() -> ThreadPoolBuilder { 91 ThreadPoolBuilder::new() 92 } 93 94 /// Spawns a future that will be run to completion. 95 /// 96 /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that 97 /// > it is guaranteed to always succeed. spawn_obj_ok(&self, future: FutureObj<'static, ()>)98 pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) { 99 let task = Task { 100 future, 101 wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }), 102 exec: self.clone(), 103 }; 104 self.state.send(Message::Run(task)); 105 } 106 107 /// Spawns a task that polls the given future with output `()` to 108 /// completion. 109 /// 110 /// ``` 111 /// use futures::executor::ThreadPool; 112 /// 113 /// let pool = ThreadPool::new().unwrap(); 114 /// 115 /// let future = async { /* ... */ }; 116 /// pool.spawn_ok(future); 117 /// ``` 118 /// 119 /// > **Note**: This method is similar to `SpawnExt::spawn`, except that 120 /// > it is guaranteed to always succeed. spawn_ok<Fut>(&self, future: Fut) where Fut: Future<Output = ()> + Send + 'static,121 pub fn spawn_ok<Fut>(&self, future: Fut) 122 where 123 Fut: Future<Output = ()> + Send + 'static, 124 { 125 self.spawn_obj_ok(FutureObj::new(Box::new(future))) 126 } 127 } 128 129 impl Spawn for ThreadPool { spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError>130 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> { 131 self.spawn_obj_ok(future); 132 Ok(()) 133 } 134 } 135 136 impl PoolState { send(&self, msg: Message)137 fn send(&self, msg: Message) { 138 self.tx.lock().unwrap().send(msg).unwrap(); 139 } 140 work( &self, idx: usize, after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>, before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>, )141 fn work( 142 &self, 143 idx: usize, 144 after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>, 145 before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>, 146 ) { 147 let _scope = enter().unwrap(); 148 if let Some(after_start) = after_start { 149 after_start(idx); 150 } 151 loop { 152 let msg = self.rx.lock().unwrap().recv().unwrap(); 153 match msg { 154 Message::Run(task) => task.run(), 155 Message::Close => break, 156 } 157 } 158 if let Some(before_stop) = before_stop { 159 before_stop(idx); 160 } 161 } 162 } 163 164 impl Clone for ThreadPool { clone(&self) -> Self165 fn clone(&self) -> Self { 166 self.state.cnt.fetch_add(1, Ordering::Relaxed); 167 Self { state: self.state.clone() } 168 } 169 } 170 171 impl Drop for ThreadPool { drop(&mut self)172 fn drop(&mut self) { 173 if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 { 174 for _ in 0..self.state.size { 175 self.state.send(Message::Close); 176 } 177 } 178 } 179 } 180 181 impl ThreadPoolBuilder { 182 /// Create a default thread pool configuration. 183 /// 184 /// See the other methods on this type for details on the defaults. new() -> Self185 pub fn new() -> Self { 186 Self { 187 pool_size: cmp::max(1, num_cpus::get()), 188 stack_size: 0, 189 name_prefix: None, 190 after_start: None, 191 before_stop: None, 192 } 193 } 194 195 /// Set size of a future ThreadPool 196 /// 197 /// The size of a thread pool is the number of worker threads spawned. By 198 /// default, this is equal to the number of CPU cores. 199 /// 200 /// # Panics 201 /// 202 /// Panics if `pool_size == 0`. pool_size(&mut self, size: usize) -> &mut Self203 pub fn pool_size(&mut self, size: usize) -> &mut Self { 204 assert!(size > 0); 205 self.pool_size = size; 206 self 207 } 208 209 /// Set stack size of threads in the pool, in bytes. 210 /// 211 /// By default, worker threads use Rust's standard stack size. stack_size(&mut self, stack_size: usize) -> &mut Self212 pub fn stack_size(&mut self, stack_size: usize) -> &mut Self { 213 self.stack_size = stack_size; 214 self 215 } 216 217 /// Set thread name prefix of a future ThreadPool. 218 /// 219 /// Thread name prefix is used for generating thread names. For example, if prefix is 220 /// `my-pool-`, then threads in the pool will get names like `my-pool-1` etc. 221 /// 222 /// By default, worker threads are assigned Rust's standard thread name. name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self223 pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self { 224 self.name_prefix = Some(name_prefix.into()); 225 self 226 } 227 228 /// Execute the closure `f` immediately after each worker thread is started, 229 /// but before running any tasks on it. 230 /// 231 /// This hook is intended for bookkeeping and monitoring. 232 /// The closure `f` will be dropped after the `builder` is dropped 233 /// and all worker threads in the pool have executed it. 234 /// 235 /// The closure provided will receive an index corresponding to the worker 236 /// thread it's running on. after_start<F>(&mut self, f: F) -> &mut Self where F: Fn(usize) + Send + Sync + 'static,237 pub fn after_start<F>(&mut self, f: F) -> &mut Self 238 where 239 F: Fn(usize) + Send + Sync + 'static, 240 { 241 self.after_start = Some(Arc::new(f)); 242 self 243 } 244 245 /// Execute closure `f` just prior to shutting down each worker thread. 246 /// 247 /// This hook is intended for bookkeeping and monitoring. 248 /// The closure `f` will be dropped after the `builder` is dropped 249 /// and all threads in the pool have executed it. 250 /// 251 /// The closure provided will receive an index corresponding to the worker 252 /// thread it's running on. before_stop<F>(&mut self, f: F) -> &mut Self where F: Fn(usize) + Send + Sync + 'static,253 pub fn before_stop<F>(&mut self, f: F) -> &mut Self 254 where 255 F: Fn(usize) + Send + Sync + 'static, 256 { 257 self.before_stop = Some(Arc::new(f)); 258 self 259 } 260 261 /// Create a [`ThreadPool`](ThreadPool) with the given configuration. create(&mut self) -> Result<ThreadPool, io::Error>262 pub fn create(&mut self) -> Result<ThreadPool, io::Error> { 263 let (tx, rx) = mpsc::channel(); 264 let pool = ThreadPool { 265 state: Arc::new(PoolState { 266 tx: Mutex::new(tx), 267 rx: Mutex::new(rx), 268 cnt: AtomicUsize::new(1), 269 size: self.pool_size, 270 }), 271 }; 272 273 for counter in 0..self.pool_size { 274 let state = pool.state.clone(); 275 let after_start = self.after_start.clone(); 276 let before_stop = self.before_stop.clone(); 277 let mut thread_builder = thread::Builder::new(); 278 if let Some(ref name_prefix) = self.name_prefix { 279 thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter)); 280 } 281 if self.stack_size > 0 { 282 thread_builder = thread_builder.stack_size(self.stack_size); 283 } 284 thread_builder.spawn(move || state.work(counter, after_start, before_stop))?; 285 } 286 Ok(pool) 287 } 288 } 289 290 impl Default for ThreadPoolBuilder { default() -> Self291 fn default() -> Self { 292 Self::new() 293 } 294 } 295 296 /// A task responsible for polling a future to completion. 297 struct Task { 298 future: FutureObj<'static, ()>, 299 exec: ThreadPool, 300 wake_handle: Arc<WakeHandle>, 301 } 302 303 struct WakeHandle { 304 mutex: UnparkMutex<Task>, 305 exec: ThreadPool, 306 } 307 308 impl Task { 309 /// Actually run the task (invoking `poll` on the future) on the current 310 /// thread. run(self)311 fn run(self) { 312 let Self { mut future, wake_handle, mut exec } = self; 313 let waker = waker_ref(&wake_handle); 314 let mut cx = Context::from_waker(&waker); 315 316 // Safety: The ownership of this `Task` object is evidence that 317 // we are in the `POLLING`/`REPOLL` state for the mutex. 318 unsafe { 319 wake_handle.mutex.start_poll(); 320 321 loop { 322 let res = future.poll_unpin(&mut cx); 323 match res { 324 Poll::Pending => {} 325 Poll::Ready(()) => return wake_handle.mutex.complete(), 326 } 327 let task = Self { future, wake_handle: wake_handle.clone(), exec }; 328 match wake_handle.mutex.wait(task) { 329 Ok(()) => return, // we've waited 330 Err(task) => { 331 // someone's notified us 332 future = task.future; 333 exec = task.exec; 334 } 335 } 336 } 337 } 338 } 339 } 340 341 impl fmt::Debug for Task { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result342 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 343 f.debug_struct("Task").field("contents", &"...").finish() 344 } 345 } 346 347 impl ArcWake for WakeHandle { wake_by_ref(arc_self: &Arc<Self>)348 fn wake_by_ref(arc_self: &Arc<Self>) { 349 match arc_self.mutex.notify() { 350 Ok(task) => arc_self.exec.state.send(Message::Run(task)), 351 Err(()) => {} 352 } 353 } 354 } 355 356 #[cfg(test)] 357 mod tests { 358 use super::*; 359 use std::sync::mpsc; 360 361 #[test] test_drop_after_start()362 fn test_drop_after_start() { 363 let (tx, rx) = mpsc::sync_channel(2); 364 let _cpu_pool = ThreadPoolBuilder::new() 365 .pool_size(2) 366 .after_start(move |_| tx.send(1).unwrap()) 367 .create() 368 .unwrap(); 369 370 // After ThreadPoolBuilder is deconstructed, the tx should be dropped 371 // so that we can use rx as an iterator. 372 let count = rx.into_iter().count(); 373 assert_eq!(count, 2); 374 } 375 } 376