1 #![cfg(test)]
2
3 use crate::{ThreadPoolBuildError, ThreadPoolBuilder};
4 use std::sync::atomic::{AtomicUsize, Ordering};
5 use std::sync::{Arc, Barrier};
6
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn worker_thread_index() {
    const THREADS: usize = 22;
    let pool = ThreadPoolBuilder::new().num_threads(THREADS).build().unwrap();
    assert_eq!(pool.current_num_threads(), THREADS);

    // The test thread is not one of this pool's workers...
    assert_eq!(pool.current_thread_index(), None);

    // ...but a closure run through `install` executes on a worker, whose
    // index must lie in `0..THREADS`.
    let idx = pool.install(|| pool.current_thread_index().unwrap());
    assert!(idx < THREADS);
}
16
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn start_callback_called() {
    const N_THREADS: usize = 16;
    let counter = Arc::new(AtomicUsize::new(0));
    // Rendezvous point for every worker in the pool plus this test thread.
    let rendezvous = Arc::new(Barrier::new(N_THREADS + 1));

    // Each worker bumps the counter, then blocks until all have arrived.
    let start_handler = {
        let counter = Arc::clone(&counter);
        let rendezvous = Arc::clone(&rendezvous);
        move |_| {
            counter.fetch_add(1, Ordering::SeqCst);
            rendezvous.wait();
        }
    };

    // Build the pool and drop the handle straight away; the workers still
    // run their start handler.
    drop(
        ThreadPoolBuilder::new()
            .num_threads(N_THREADS)
            .start_handler(start_handler)
            .build()
            .unwrap(),
    );

    // Block until every worker has entered its start handler.
    rendezvous.wait();

    // Exactly one start-handler invocation per worker thread.
    assert_eq!(counter.load(Ordering::SeqCst), N_THREADS);
}
43
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn exit_callback_called() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Wait for all the threads in the pool plus the one running tests.
    let barrier = Arc::new(Barrier::new(n_threads + 1));

    let b = Arc::clone(&barrier);
    let nc = Arc::clone(&n_called);
    let exit_handler = move |_| {
        nc.fetch_add(1, Ordering::SeqCst);
        b.wait();
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .exit_handler(exit_handler);

    // Build the pool and drop it immediately: dropping it stops the running
    // threads, and each of them then runs the exit handler. (`let _ = ...`
    // inside a scope block would also drop it at once, but reads as if the
    // pool lived until the end of the block.)
    drop(conf.build().unwrap());

    // Wait for all the threads to have run their exit handler.
    barrier.wait();

    // The handler must have been called on every exiting thread.
    assert_eq!(n_called.load(Ordering::SeqCst), n_threads);
}
73
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn handler_panics_handled_correctly() {
    let n_threads = 16;
    let n_called = Arc::new(AtomicUsize::new(0));
    // Wait for all the threads in the pool plus the one running tests.
    let start_barrier = Arc::new(Barrier::new(n_threads + 1));
    let exit_barrier = Arc::new(Barrier::new(n_threads + 1));

    let start_handler = move |_| {
        panic!("ensure panic handler is called when starting");
    };
    let exit_handler = move |_| {
        panic!("ensure panic handler is called when exiting");
    };

    let sb = Arc::clone(&start_barrier);
    let eb = Arc::clone(&exit_barrier);
    let nc = Arc::clone(&n_called);
    // The first `n_threads` calls come from start-handler panics. The pool is
    // kept alive until `start_barrier` releases, so no thread can be stopped
    // (and panic in its exit handler) before all of those are counted.
    let panic_handler = move |_| {
        let val = nc.fetch_add(1, Ordering::SeqCst);
        if val < n_threads {
            sb.wait();
        } else {
            eb.wait();
        }
    };

    let conf = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .start_handler(start_handler)
        .exit_handler(exit_handler)
        .panic_handler(panic_handler);
    {
        // Bind the pool (`let _ = ...` would drop it immediately) so it stays
        // alive for the whole start phase.
        let _pool = conf.build().unwrap();

        // Wait for all the threads to start, panic in the start handler,
        // and have been handled by the panic handler.
        start_barrier.wait();

        // `_pool` is dropped here, which stops the running threads.
    }

    // Wait for all the threads to exit, panic in the exit handler,
    // and have been handled by the panic handler.
    exit_barrier.wait();

    // The panic handler must have been called twice on every thread.
    assert_eq!(n_called.load(Ordering::SeqCst), 2 * n_threads);
}
124
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn check_config_build() {
    // The built pool must report exactly the thread count it was configured with.
    let requested = 22;
    let built = ThreadPoolBuilder::new()
        .num_threads(requested)
        .build()
        .unwrap();
    assert_eq!(built.current_num_threads(), requested);
}
131
/// Compile-time assertion helper: `_send_sync::<T>()` only type-checks when
/// `T` is both `Send` and `Sync`. Used by `check_error_send_sync` for
/// `ThreadPoolBuildError`.
fn _send_sync<T>()
where
    T: Send + Sync,
{
}
134
/// `ThreadPoolBuildError` must be `Send + Sync` so that it can be moved and
/// shared across threads; this test fails to compile otherwise.
#[test]
fn check_error_send_sync() {
    _send_sync::<crate::ThreadPoolBuildError>();
}
139
#[allow(deprecated)] // deliberately exercises the deprecated `Configuration` API
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn configuration() {
    // Ensure we can call all public methods on Configuration; the handlers
    // are no-ops because only the builder surface is under test here.
    let builder = crate::Configuration::new()
        .thread_name(|i| format!("thread_name_{}", i))
        .num_threads(5)
        .panic_handler(|_| {})
        .stack_size(4_000_000)
        .breadth_first()
        .start_handler(|_| {})
        .exit_handler(|_| {});
    builder.build().unwrap();
}
161
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn default_pool() {
    // A builder with no explicit settings must still build successfully.
    let builder = ThreadPoolBuilder::default();
    builder.build().unwrap();
}
167
/// Custom-spawned worker threads must have their `WorkerThread` cleared once
/// the pool is finished with them, so the same OS threads can later be used
/// with rayon again (e.g. WebAssembly, which wants to manage its own pool of
/// available threads).
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn cleared_current_thread() -> Result<(), ThreadPoolBuildError> {
    let n_threads = 5;
    let mut joiners = Vec::with_capacity(n_threads);
    let pool = ThreadPoolBuilder::new()
        .num_threads(n_threads)
        .spawn_handler(|builder| {
            let joiner = std::thread::spawn(move || {
                builder.run();

                // Once `run` returns, this OS thread must no longer be
                // registered as a rayon worker.
                assert_eq!(crate::current_thread_index(), None);
            });
            joiners.push(joiner);
            Ok(())
        })
        .build()?;
    // The spawn handler was invoked once per requested thread.
    assert_eq!(joiners.len(), n_threads);

    // While running inside the pool, the current thread is a worker.
    pool.install(|| assert!(crate::current_thread_index().is_some()));
    drop(pool);

    // Join every worker so its post-`run` assertion is observed; a failed
    // assertion surfaces here as a panic from `unwrap`.
    joiners
        .into_iter()
        .for_each(|joiner| joiner.join().unwrap());

    Ok(())
}
201