use crate::enter;
use crate::unpark_mutex::UnparkMutex;
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use futures_task::{waker_ref, ArcWake};
use futures_task::{FutureObj, Spawn, SpawnError};
use futures_util::future::FutureExt;
use std::cmp;
use std::fmt;
use std::io;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;

/// A general-purpose thread pool for scheduling tasks that poll futures to
/// completion.
///
/// The thread pool multiplexes any number of tasks onto a fixed number of
/// worker threads.
///
/// This type is a clonable handle to the threadpool itself.
/// Cloning it will only create a new reference, not a new threadpool.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPool {
    state: Arc<PoolState>,
}

/// Thread pool configuration object.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    pool_size: usize,
    stack_size: usize,
    name_prefix: Option<String>,
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}

trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}

struct PoolState {
    tx: Mutex<mpsc::Sender<Message>>,
    rx: Mutex<mpsc::Receiver<Message>>,
    cnt: AtomicUsize,
    size: usize,
}

impl fmt::Debug for ThreadPool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
    }
}

impl fmt::Debug for ThreadPoolBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadPoolBuilder")
            .field("pool_size", &self.pool_size)
            .field("name_prefix", &self.name_prefix)
            .finish()
    }
}

enum Message {
    Run(Task),
    Close,
}
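// How scheduling works: the pool uses a single `std::sync::mpsc` channel as a
// shared run queue. `PoolState::tx` is the sending half (used when spawning and
// when a task is woken), and every worker thread repeatedly locks `rx` and
// blocks in `recv()` for the next `Message`. `Message::Run` carries a task to
// poll; `Message::Close` tells exactly one worker to shut down. `cnt` counts
// live `ThreadPool` handles so that dropping the last one can send one `Close`
// per worker (see the `Clone` and `Drop` impls below).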
impl ThreadPool {
    /// Creates a new thread pool with the default configuration.
    ///
    /// See documentation for the methods in
    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
    /// configuration.
    pub fn new() -> Result<Self, io::Error> {
        ThreadPoolBuilder::new().create()
    }

    /// Create a default thread pool configuration, which can then be customized.
    ///
    /// See documentation for the methods in
    /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
    /// configuration.
    pub fn builder() -> ThreadPoolBuilder {
        ThreadPoolBuilder::new()
    }

    /// Spawns a future that will be run to completion.
    ///
    /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that
    /// > it is guaranteed to always succeed.
    pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
        let task = Task {
            future,
            wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
            exec: self.clone(),
        };
        self.state.send(Message::Run(task));
    }

    /// Spawns a task that polls the given future with output `()` to
    /// completion.
    ///
    /// ```
    /// # {
    /// use futures::executor::ThreadPool;
    ///
    /// let pool = ThreadPool::new().unwrap();
    ///
    /// let future = async { /* ... */ };
    /// pool.spawn_ok(future);
    /// # }
    /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    /// ```
    ///
    /// > **Note**: This method is similar to `SpawnExt::spawn`, except that
    /// > it is guaranteed to always succeed.
    pub fn spawn_ok<Fut>(&self, future: Fut)
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.spawn_obj_ok(FutureObj::new(Box::new(future)))
    }
}

impl Spawn for ThreadPool {
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
        self.spawn_obj_ok(future);
        Ok(())
    }
}
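// Illustrative sketch (not part of the original source): spawning through the
// `Spawn` trait object, the way executor-agnostic code would. `spawn_via_trait`
// is a hypothetical helper name; `FutureObj` wraps the boxed future, and the
// `Spawn` impl above forwards to `spawn_obj_ok`, so for `ThreadPool` this can
// never actually return a `SpawnError`. Gated behind `cfg(test)` so it does not
// affect the library build. Usage would look like `spawn_via_trait(&pool)?`
// for any `pool: ThreadPool`.
#[cfg(test)]
#[allow(dead_code)]
fn spawn_via_trait(spawner: &dyn Spawn) -> Result<(), SpawnError> {
    // Box the future and erase its concrete type so it can cross the trait boundary.
    let future = FutureObj::new(Box::new(async { /* e.g. do some work */ }));
    // For `ThreadPool`, this is equivalent to `pool.spawn_obj_ok(future)`.
    spawner.spawn_obj(future)
}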
impl PoolState {
    fn send(&self, msg: Message) {
        self.tx.lock().unwrap().send(msg).unwrap();
    }

    fn work(
        &self,
        idx: usize,
        after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
        before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    ) {
        let _scope = enter().unwrap();
        if let Some(after_start) = after_start {
            after_start(idx);
        }
        loop {
            let msg = self.rx.lock().unwrap().recv().unwrap();
            match msg {
                Message::Run(task) => task.run(),
                Message::Close => break,
            }
        }
        if let Some(before_stop) = before_stop {
            before_stop(idx);
        }
    }
}

impl Clone for ThreadPool {
    fn clone(&self) -> Self {
        self.state.cnt.fetch_add(1, Ordering::Relaxed);
        Self { state: self.state.clone() }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
            for _ in 0..self.state.size {
                self.state.send(Message::Close);
            }
        }
    }
}

impl ThreadPoolBuilder {
    /// Create a default thread pool configuration.
    ///
    /// See the other methods on this type for details on the defaults.
    pub fn new() -> Self {
        Self {
            pool_size: cmp::max(1, num_cpus::get()),
            stack_size: 0,
            name_prefix: None,
            after_start: None,
            before_stop: None,
        }
    }

    /// Set the size of the future ThreadPool.
    ///
    /// The size of a thread pool is the number of worker threads spawned. By
    /// default, this is equal to the number of CPU cores.
    ///
    /// # Panics
    ///
    /// Panics if `pool_size == 0`.
    pub fn pool_size(&mut self, size: usize) -> &mut Self {
        assert!(size > 0);
        self.pool_size = size;
        self
    }

    /// Set the stack size of threads in the pool, in bytes.
    ///
    /// By default, worker threads use Rust's standard stack size.
    pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
        self.stack_size = stack_size;
        self
    }

    /// Set the thread name prefix of the future ThreadPool.
    ///
    /// The thread name prefix is used for generating thread names. For example,
    /// if the prefix is `my-pool-`, then threads in the pool will get names like
    /// `my-pool-0`, `my-pool-1`, etc.
    ///
    /// By default, worker threads are assigned Rust's standard thread name.
    pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
        self.name_prefix = Some(name_prefix.into());
        self
    }

    /// Execute the closure `f` immediately after each worker thread is started,
    /// but before running any tasks on it.
    ///
    /// This hook is intended for bookkeeping and monitoring.
    /// The closure `f` will be dropped after the `builder` is dropped
    /// and all worker threads in the pool have executed it.
    ///
    /// The closure provided will receive an index corresponding to the worker
    /// thread it's running on.
    pub fn after_start<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.after_start = Some(Arc::new(f));
        self
    }

    /// Execute the closure `f` just prior to shutting down each worker thread.
    ///
    /// This hook is intended for bookkeeping and monitoring.
    /// The closure `f` will be dropped after the `builder` is dropped
    /// and all threads in the pool have executed it.
    ///
    /// The closure provided will receive an index corresponding to the worker
    /// thread it's running on.
    pub fn before_stop<F>(&mut self, f: F) -> &mut Self
    where
        F: Fn(usize) + Send + Sync + 'static,
    {
        self.before_stop = Some(Arc::new(f));
        self
    }

    /// Create a [`ThreadPool`](ThreadPool) with the given configuration.
    pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
        let (tx, rx) = mpsc::channel();
        let pool = ThreadPool {
            state: Arc::new(PoolState {
                tx: Mutex::new(tx),
                rx: Mutex::new(rx),
                cnt: AtomicUsize::new(1),
                size: self.pool_size,
            }),
        };

        for counter in 0..self.pool_size {
            let state = pool.state.clone();
            let after_start = self.after_start.clone();
            let before_stop = self.before_stop.clone();
            let mut thread_builder = thread::Builder::new();
            if let Some(ref name_prefix) = self.name_prefix {
                thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
            }
            if self.stack_size > 0 {
                thread_builder = thread_builder.stack_size(self.stack_size);
            }
            thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
        }
        Ok(pool)
    }
}

impl Default for ThreadPoolBuilder {
    fn default() -> Self {
        Self::new()
    }
}
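// Illustrative sketch (not part of the original source): exercising the builder
// API end to end. `builder_usage_sketch` is a hypothetical helper, and the
// concrete sizes and names are arbitrary example values. Gated behind
// `cfg(test)` so it does not affect the library build.
#[cfg(test)]
#[allow(dead_code)]
fn builder_usage_sketch() -> Result<ThreadPool, io::Error> {
    ThreadPoolBuilder::new()
        .pool_size(2) // two worker threads
        .stack_size(512 * 1024) // 512 KiB worker stacks instead of the platform default
        .name_prefix("my-pool-") // workers named "my-pool-0" and "my-pool-1"
        .after_start(|idx| println!("worker {} started", idx))
        .before_stop(|idx| println!("worker {} stopping", idx))
        .create()
}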
/// A task responsible for polling a future to completion.
struct Task {
    future: FutureObj<'static, ()>,
    exec: ThreadPool,
    wake_handle: Arc<WakeHandle>,
}

struct WakeHandle {
    mutex: UnparkMutex<Task>,
    exec: ThreadPool,
}

impl Task {
    /// Actually run the task (invoking `poll` on the future) on the current
    /// thread.
    fn run(self) {
        let Self { mut future, wake_handle, mut exec } = self;
        let waker = waker_ref(&wake_handle);
        let mut cx = Context::from_waker(&waker);

        // Safety: The ownership of this `Task` object is evidence that
        // we are in the `POLLING`/`REPOLL` state for the mutex.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = future.poll_unpin(&mut cx);
                match res {
                    Poll::Pending => {}
                    Poll::Ready(()) => return wake_handle.mutex.complete(),
                }
                let task = Self { future, wake_handle: wake_handle.clone(), exec };
                match wake_handle.mutex.wait(task) {
                    Ok(()) => return, // we've waited
                    Err(task) => {
                        // someone's notified us
                        future = task.future;
                        exec = task.exec;
                    }
                }
            }
        }
    }
}

impl fmt::Debug for Task {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Task").field("contents", &"...").finish()
    }
}

impl ArcWake for WakeHandle {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        if let Ok(task) = arc_self.mutex.notify() {
            arc_self.exec.state.send(Message::Run(task))
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;

    #[test]
    fn test_drop_after_start() {
        {
            let (tx, rx) = mpsc::sync_channel(2);
            let _cpu_pool = ThreadPoolBuilder::new()
                .pool_size(2)
                .after_start(move |_| tx.send(1).unwrap())
                .create()
                .unwrap();

            // After the ThreadPoolBuilder has been dropped, the tx should be
            // dropped so that we can use rx as an iterator.
            let count = rx.into_iter().count();
            assert_eq!(count, 2);
        }
        std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    }
}
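// Additional illustrative tests (not part of the original source; the module
// and test names are newly introduced). The first checks that a spawned future
// actually runs to completion; the second checks the handle-counting shutdown
// path: `before_stop` hooks only fire once the last `ThreadPool` clone has
// been dropped.
#[cfg(test)]
mod usage_tests {
    use super::*;
    use std::sync::mpsc;

    #[test]
    fn spawn_ok_runs_future_to_completion() {
        let (tx, rx) = mpsc::channel();
        let pool = ThreadPool::builder().pool_size(1).create().unwrap();
        pool.spawn_ok(async move {
            tx.send(42).unwrap();
        });
        // The worker thread polls the future, which sends the value back.
        assert_eq!(rx.recv().unwrap(), 42);
        std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    }

    #[test]
    fn before_stop_runs_after_last_handle_drops() {
        let (tx, rx) = mpsc::sync_channel(2);
        let pool = ThreadPoolBuilder::new()
            .pool_size(2)
            .before_stop(move |_| tx.send(()).unwrap())
            .create()
            .unwrap();
        let second_handle = pool.clone();

        drop(pool);
        // One handle is still alive, so no `Close` messages have been sent yet
        // and no worker has run its `before_stop` hook.
        assert!(rx.try_recv().is_err());

        drop(second_handle);
        // Dropping the last handle sends one `Close` per worker; each worker
        // runs the `before_stop` hook on its way out.
        assert_eq!(rx.iter().take(2).count(), 2);
        std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
    }
}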