• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Tests for the join code.
2 
3 use crate::join::*;
4 use crate::unwind;
5 use crate::ThreadPoolBuilder;
6 use rand::distributions::Standard;
7 use rand::{Rng, SeedableRng};
8 use rand_xorshift::XorShiftRng;
9 
quick_sort<T: PartialOrd + Send>(v: &mut [T])10 fn quick_sort<T: PartialOrd + Send>(v: &mut [T]) {
11     if v.len() <= 1 {
12         return;
13     }
14 
15     let mid = partition(v);
16     let (lo, hi) = v.split_at_mut(mid);
17     join(|| quick_sort(lo), || quick_sort(hi));
18 }
19 
/// Lomuto partition of `v` around its last element.
///
/// Rearranges `v` so that every element `<=` the pivot comes before it and
/// every other element comes after it. Returns the pivot's final index.
/// An empty slice is left untouched and yields `0`.
fn partition<T: PartialOrd + Send>(v: &mut [T]) -> usize {
    // Guard the empty case: `v.len() - 1` below would otherwise underflow.
    if v.is_empty() {
        return 0;
    }

    let pivot = v.len() - 1;
    // `v[..store]` holds the elements seen so far that compare `<=` the pivot.
    let mut store = 0;
    for j in 0..pivot {
        if v[j] <= v[pivot] {
            v.swap(store, j);
            store += 1;
        }
    }
    // Move the pivot between the two groups, into its sorted position.
    v.swap(store, pivot);
    store
}
32 
seeded_rng() -> XorShiftRng33 fn seeded_rng() -> XorShiftRng {
34     let mut seed = <XorShiftRng as SeedableRng>::Seed::default();
35     (0..).zip(seed.as_mut()).for_each(|(i, x)| *x = i);
36     XorShiftRng::from_seed(seed)
37 }
38 
/// Sorts pseudo-random data with `quick_sort` (outside any explicit pool)
/// and checks the result against the standard library's sort.
#[test]
fn sort() {
    let mut data: Vec<u32> = seeded_rng().sample_iter(&Standard).take(6 * 1024).collect();
    let mut expected = data.clone();
    expected.sort_unstable();
    quick_sort(&mut data);
    assert_eq!(data, expected);
}
48 
/// Same as `sort`, but with twice the data and run via `install` on a
/// freshly built thread pool.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn sort_in_pool() {
    let mut data: Vec<u32> = seeded_rng().sample_iter(&Standard).take(12 * 1024).collect();

    let pool = ThreadPoolBuilder::new().build().unwrap();
    let expected = {
        let mut reference = data.clone();
        reference.sort_unstable();
        reference
    };
    pool.install(|| quick_sort(&mut data));
    assert_eq!(data, expected);
}
61 
/// A panic in the first closure must be re-raised by `join` in the caller.
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_a() {
    let panicking = || panic!("Hello, world!");
    let quiet = || ();
    join(panicking, quiet);
}
67 
/// A panic in the second closure must be re-raised by `join` in the caller.
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_b() {
    let quiet = || ();
    let panicking = || panic!("Hello, world!");
    join(quiet, panicking);
}
73 
/// When both closures panic, the test expects the first closure's payload
/// ("Hello, world!") to be the one that surfaces from `join`.
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_both() {
    let first = || panic!("Hello, world!");
    let second = || panic!("Goodbye, world!");
    join(first, second);
}
79 
/// Even when closure A panics, `join` must still run closure B before the
/// panic is propagated to the caller.
#[test]
#[cfg_attr(not(panic = "unwind"), ignore)]
fn panic_b_still_executes() {
    let mut b_ran = false;
    match unwind::halt_unwinding(|| join(|| panic!("Hello, world!"), || b_ran = true)) {
        Ok(_) => panic!("failed to propagate panic from closure A"),
        Err(_) => assert!(b_ran, "closure B failed to execute"),
    }
}
89 
/// Called from outside any thread pool, both closures are injected into a
/// pool, so each one should report that it migrated.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_context_both() {
    let migrated = join_context(|ctx| ctx.migrated(), |ctx| ctx.migrated());
    assert!(migrated.0);
    assert!(migrated.1);
}
98 
/// Inside a single-threaded pool there is no other worker to steal a job,
/// so neither closure should report that it migrated.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_context_neither() {
    let pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
    let migrated = pool.install(|| join_context(|ctx| ctx.migrated(), |ctx| ctx.migrated()));
    assert!(!migrated.0);
    assert!(!migrated.1);
}
109 
/// In a two-thread pool, the second job should be stolen while the first
/// stays on the calling worker.
///
/// A two-party barrier forces both closures to be running at the same time,
/// so the second one cannot simply be executed after the first on the same
/// thread.
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_context_second() {
    use std::sync::Barrier;

    /// Waits on `barrier` until the other party arrives, then evaluates `f`.
    fn after_rendezvous<R>(barrier: &Barrier, f: impl FnOnce() -> R) -> R {
        barrier.wait();
        f()
    }

    let barrier = Barrier::new(2);
    let pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();
    let (a_migrated, b_migrated) = pool.install(|| {
        join_context(
            |a| after_rendezvous(&barrier, || a.migrated()),
            |b| after_rendezvous(&barrier, || b.migrated()),
        )
    });
    assert!(!a_migrated);
    assert!(b_migrated);
}
133 
/// Calls `join` on a two-thread pool many times in a row.
///
/// This used to hit overflow debug-assertions in JEC on 32-bit targets:
/// https://github.com/rayon-rs/rayon/issues/797
#[test]
#[cfg_attr(any(target_os = "emscripten", target_family = "wasm"), ignore)]
fn join_counter_overflow() {
    const MAX: u32 = 500_000;

    let (mut left, mut right) = (0, 0);
    let pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();

    (0..MAX).for_each(|_| {
        pool.join(|| left += 1, || right += 1);
    });

    assert_eq!(left, MAX);
    assert_eq!(right, MAX);
}
152