• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 use std::{
16     sync::{mpsc, Arc, Mutex},
17     thread,
18 };
19 
20 pub struct ThreadPool {
21     workers: Vec<Worker>,
22     sender: Option<mpsc::Sender<Job>>,
23 }
24 
25 type Job = Box<dyn FnOnce() + Send + 'static>;
26 
27 impl ThreadPool {
28     /// Create a new ThreadPool.
29     ///
30     /// The size is the number of threads in the pool.
31     ///
32     /// # Panics
33     ///
34     /// The `new` function will panic if the size is zero.
new(size: usize) -> ThreadPool35     pub fn new(size: usize) -> ThreadPool {
36         assert!(size > 0);
37 
38         let (sender, receiver) = mpsc::channel();
39 
40         let receiver = Arc::new(Mutex::new(receiver));
41 
42         let mut workers = Vec::with_capacity(size);
43 
44         for id in 0..size {
45             workers.push(Worker::new(id, Arc::clone(&receiver)));
46         }
47 
48         ThreadPool { workers, sender: Some(sender) }
49     }
50 
execute<F>(&self, f: F) where F: FnOnce() + Send + 'static,51     pub fn execute<F>(&self, f: F)
52     where
53         F: FnOnce() + Send + 'static,
54     {
55         let job = Box::new(f);
56 
57         self.sender.as_ref().unwrap().send(job).unwrap();
58     }
59 }
60 
61 impl Drop for ThreadPool {
drop(&mut self)62     fn drop(&mut self) {
63         drop(self.sender.take());
64 
65         for worker in &mut self.workers {
66             println!("Shutting down worker {}", worker.id);
67 
68             if let Some(thread) = worker.thread.take() {
69                 thread.join().unwrap();
70             }
71         }
72     }
73 }
74 
/// A single thread owned by a [`ThreadPool`], labelled with `id` in its log messages.
///
/// `thread` holds the join handle; it becomes `None` once the handle has been taken
/// so the thread can be joined.
struct Worker {
    id: usize,
    thread: Option<thread::JoinHandle<()>>,
}
79 
80 impl Worker {
new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker81     fn new(id: usize, receiver: Arc<Mutex<mpsc::Receiver<Job>>>) -> Worker {
82         let thread = thread::spawn(move || loop {
83             let message = receiver.lock().unwrap().recv();
84 
85             match message {
86                 Ok(job) => {
87                     job();
88                 }
89                 Err(_) => {
90                     println!("Worker {id} disconnected; shutting down.");
91                     break;
92                 }
93             }
94         });
95 
96         Worker { id, thread: Some(thread) }
97     }
98 }
99