#![cfg(test)]

#[allow(deprecated)]
use crate::Configuration;
use crate::{ThreadPoolBuildError, ThreadPoolBuilder};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};

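/// `current_thread_index()` is `None` on the calling thread and a valid
/// index (less than `num_threads`) inside `install`.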
#[test]
fn worker_thread_index() {
    let pool = ThreadPoolBuilder::new().num_threads(22).build().unwrap();
    assert_eq!(pool.current_num_threads(), 22);
    assert_eq!(pool.current_thread_index(), None);
    let index = pool.install(|| pool.current_thread_index().unwrap());
    assert!(index < 22);
}

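/// The start handler must be called on every thread the pool spawns.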
#[test]
fn start_callback_called() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Wait for all the threads in the pool plus the one running tests.
    let barrier = Arc::new(Barrier::new(n_threads + 1));

    let b = barrier.clone();
    let nc = n_called.clone();
    let start_handler = move |_| {
        nc.fetch_add(1, Ordering::SeqCst);
        b.wait();
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler);
    let _ = conf.build().unwrap();

    // Wait for all the threads to have been scheduled to run.
    barrier.wait();

    // The handler must have been called on every started thread.
    assert_eq!(n_called.load(Ordering::SeqCst), n_threads);
}

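/// The exit handler must be called on every thread when the pool is dropped.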
#[test]
fn exit_callback_called() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Wait for all the threads in the pool plus the one running tests.
    let barrier = Arc::new(Barrier::new(n_threads + 1));

    let b = barrier.clone();
    let nc = n_called.clone();
    let exit_handler = move |_| {
        nc.fetch_add(1, Ordering::SeqCst);
        b.wait();
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .exit_handler(exit_handler);
    {
        let _ = conf.build().unwrap();
        // Drop the pool so it stops the running threads.
    }

    // Wait for all the threads to have run their exit handler.
    barrier.wait();

    // The handler must have been called on every exiting thread.
    assert_eq!(n_called.load(Ordering::SeqCst), n_threads);
}

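/// Panics raised in the start and exit handlers must be delivered to the
/// panic handler, once per thread for each handler.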
#[test]
fn handler_panics_handled_correctly() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Wait for all the threads in the pool plus the one running tests.
    let start_barrier = Arc::new(Barrier::new(n_threads + 1));
    let exit_barrier = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = move |_| {
        panic!("ensure panic handler is called when starting");
    };
    let exit_handler = move |_| {
        panic!("ensure panic handler is called when exiting");
    };

    let sb = start_barrier.clone();
    let eb = exit_barrier.clone();
    let nc = n_called.clone();
    let panic_handler = move |_| {
        let val = nc.fetch_add(1, Ordering::SeqCst);
        if val < n_threads {
            sb.wait();
        } else {
            eb.wait();
        }
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler)
        .exit_handler(exit_handler)
        .panic_handler(panic_handler);
    {
        let _ = conf.build().unwrap();

        // Wait for all the threads to start, panic in the start handler,
        // and be handled by the panic handler.
        start_barrier.wait();

        // Drop the pool so it stops the running threads.
    }

    // Wait for all the threads to exit, panic in the exit handler,
    // and be handled by the panic handler.
    exit_barrier.wait();

    // The panic handler must have been called twice on every thread.
    assert_eq!(n_called.load(Ordering::SeqCst), 2 * n_threads);
}

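/// A pool built with an explicit `num_threads` reports that many threads.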
#[test]
#[allow(deprecated)]
fn check_config_build() {
    let pool = ThreadPoolBuilder::new().num_threads(22).build().unwrap();
    assert_eq!(pool.current_num_threads(), 22);
}

/// Helper used by `check_error_send_sync` to ensure `ThreadPoolBuildError` is `Send + Sync`.
fn _send_sync<T: Send + Sync>() {}

#[test]
fn check_error_send_sync() {
    _send_sync::<ThreadPoolBuildError>();
}

#[allow(deprecated)]
#[test]
fn configuration() {
    let start_handler = move |_| {};
    let exit_handler = move |_| {};
    let panic_handler = move |_| {};
    let thread_name = move |i| format!("thread_name_{}", i);

    // Ensure we can call all public methods on Configuration
    Configuration::new()
        .thread_name(thread_name)
        .num_threads(5)
        .panic_handler(panic_handler)
        .stack_size(4e6 as usize)
        .breadth_first()
        .start_handler(start_handler)
        .exit_handler(exit_handler)
        .build()
        .unwrap();
}

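/// A default `ThreadPoolBuilder`, with no explicit settings, must build successfully.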
#[test]
fn default_pool() {
    ThreadPoolBuilder::default().build().unwrap();
}

/// Test that custom spawned threads get their `WorkerThread` cleared once
/// the pool is done with them, allowing them to be used with rayon again
/// later, e.g. when WebAssembly environments maintain their own pool of
/// available threads.
#[test]
fn cleared_current_thread() -> Result<(), ThreadPoolBuildError> {
    let n_threads = 5;
    let mut handles = vec![];
    let pool = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .spawn_handler(|thread| {
            let handle = std::thread::spawn(move || {
                thread.run();

                // Afterward, the current thread shouldn't be set anymore.
                assert_eq!(crate::current_thread_index(), None);
            });
            handles.push(handle);
            Ok(())
        })
        .build()?;
    assert_eq!(handles.len(), n_threads);

    pool.install(|| assert!(crate::current_thread_index().is_some()));
    drop(pool);

    // Wait for all threads to make their assertions and exit
    for handle in handles {
        handle.join().unwrap();
    }

    Ok(())
}