// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
    panic::{self, AssertUnwindSafe},
    sync::{mpsc, Arc, Mutex},
    thread,
};

/// A fixed-size pool of worker threads that run submitted closures.
///
/// Jobs travel over a single channel whose receiving end is shared by all
/// workers behind a mutex; whichever worker acquires the lock first receives
/// the next job. Dropping the pool closes the channel and joins every worker.
#[derive(Debug)]
pub struct ThreadPool {
    workers: Vec<Worker>,
    // `Option` so `Drop` can take and drop the sender before joining workers.
    sender: Option<mpsc::Sender<Job>>,
}

/// A unit of work: a one-shot closure that can be moved to another thread.
type Job = Box<dyn FnOnce() + Send + 'static>;

impl ThreadPool {
    /// Create a new ThreadPool.
    ///
    /// The size is the number of threads in the pool.
    ///
    /// # Panics
    ///
    /// The `new` function will panic if the size is zero.
    pub fn new(size: usize) -> ThreadPool {
        assert!(size > 0);

        let (sender, receiver) = mpsc::channel();
        let receiver = Arc::new(Mutex::new(receiver));

        let workers = (0..size)
            .map(|id| Worker::new(id, Arc::clone(&receiver)))
            .collect();

        ThreadPool { workers, sender: Some(sender) }
    }

    /// Queue `f` to run on one of the pool's worker threads.
    ///
    /// Returns as soon as the job is queued; it runs when a worker picks it up.
    ///
    /// # Panics
    ///
    /// Panics if no worker is left to receive the job. Workers only exit once
    /// the sender is dropped in `Drop`, and they survive panicking jobs, so
    /// this is not expected while the pool is alive.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        let job = Box::new(f);

        self.sender
            .as_ref()
            .expect("sender is only taken in Drop")
            .send(job)
            .expect("all worker threads have exited");
    }
}

impl Drop for ThreadPool {
    /// Close the job channel, then wait for every worker thread to finish.
    ///
    /// A worker whose thread panicked is reported rather than re-panicked:
    /// panicking inside `drop` while already unwinding aborts the process.
    fn drop(&mut self) {
        // Dropping the only sender makes each worker's `recv` return `Err`
        // once the queued jobs are drained, which is their signal to exit.
        drop(self.sender.take());

        for worker in &mut self.workers {
            println!("Shutting down worker {}", worker.id);

            if let Some(thread) = worker.thread.take() {
                if thread.join().is_err() {
                    eprintln!("Worker {} panicked", worker.id);
                }
            }
        }
    }
}

/// One pool thread plus the id used in its log messages.
#[derive(Debug)]
struct Worker {
    id: usize,
    // `Option` so `ThreadPool::drop` can move the handle out and join it.
    thread: Option<thread::JoinHandle<()>>,
}

impl Worker {
    /// Spawn a thread that repeatedly receives jobs from `receiver` and runs
    /// them, exiting once the channel reports that the sender is gone.
    fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
        let thread = thread::spawn(move || loop {
            // Receive in its own statement so the mutex guard is a temporary
            // dropped right here: the lock is NOT held while the job runs,
            // letting other workers receive concurrently.
            let message = receiver
                .lock()
                .expect("job queue mutex poisoned")
                .recv();

            match message {
                Ok(job) => {
                    // Contain a panicking job so this worker keeps serving the
                    // pool instead of dying (which would shrink the pool and
                    // make the later join in `ThreadPool::drop` fail).
                    if panic::catch_unwind(AssertUnwindSafe(job)).is_err() {
                        eprintln!("Worker {id} job panicked; continuing.");
                    }
                }
                Err(_) => {
                    println!("Worker {id} disconnected; shutting down.");
                    break;
                }
            }
        });

        Worker { id, thread: Some(thread) }
    }
}