• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Tests for the join code.
2 
3 use crate::join::*;
4 use crate::unwind;
5 use crate::ThreadPoolBuilder;
6 use rand::distributions::Standard;
7 use rand::{Rng, SeedableRng};
8 use rand_xorshift::XorShiftRng;
9 
quick_sort<T: PartialOrd + Send>(v: &mut [T])10 fn quick_sort<T: PartialOrd + Send>(v: &mut [T]) {
11     if v.len() <= 1 {
12         return;
13     }
14 
15     let mid = partition(v);
16     let (lo, hi) = v.split_at_mut(mid);
17     join(|| quick_sort(lo), || quick_sort(hi));
18 }
19 
/// Lomuto partition around the last element of `v`.
///
/// Every element `<=` the pivot is moved in front of it, and the pivot is
/// then placed right after them. Returns the pivot's final index.
/// The caller must pass a non-empty slice.
fn partition<T: PartialOrd + Send>(v: &mut [T]) -> usize {
    let last = v.len() - 1;
    let mut boundary = 0;
    {
        // Scan everything before the pivot while holding the pivot separately.
        let (prefix, tail) = v.split_at_mut(last);
        let pivot = &tail[0];
        for probe in 0..prefix.len() {
            if prefix[probe] <= *pivot {
                prefix.swap(boundary, probe);
                boundary += 1;
            }
        }
    }
    v.swap(boundary, last);
    boundary
}
32 
seeded_rng() -> XorShiftRng33 fn seeded_rng() -> XorShiftRng {
34     let mut seed = <XorShiftRng as SeedableRng>::Seed::default();
35     (0..).zip(seed.as_mut()).for_each(|(i, x)| *x = i);
36     XorShiftRng::from_seed(seed)
37 }
38 
/// `quick_sort` (driven by `join` outside any explicit pool) must agree with
/// the standard library sort on deterministic random data.
#[test]
fn sort() {
    let mut actual: Vec<u32> = seeded_rng().sample_iter(&Standard).take(6 * 1024).collect();
    let mut expected = actual.clone();
    expected.sort();

    quick_sort(&mut actual);
    assert_eq!(actual, expected);
}
48 
/// Same as `sort`, but with a larger input and run via `install` inside a
/// freshly built thread pool.
#[test]
fn sort_in_pool() {
    let mut actual: Vec<u32> = seeded_rng().sample_iter(&Standard).take(12 * 1024).collect();
    let mut expected = actual.clone();
    expected.sort();

    let pool = ThreadPoolBuilder::new().build().unwrap();
    pool.install(|| quick_sort(&mut actual));
    assert_eq!(actual, expected);
}
60 
/// A panic in the first closure must surface from `join`.
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_a() {
    let first = || panic!("Hello, world!");
    let second = || ();
    join(first, second);
}
66 
/// A panic in the second closure must surface from `join`.
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_b() {
    let first = || ();
    let second = || panic!("Hello, world!");
    join(first, second);
}
72 
/// When both closures panic, the panic that surfaces from `join` is the
/// first closure's ("Hello, world!").
#[test]
#[should_panic(expected = "Hello, world!")]
fn panic_propagate_both() {
    let first = || panic!("Hello, world!");
    let second = || panic!("Goodbye, world!");
    join(first, second);
}
78 
/// Even when the first closure panics, `join` must still run the second
/// closure before propagating the panic.
#[test]
fn panic_b_still_executes() {
    let mut x = false;
    // `halt_unwinding` catches the propagated panic so `x` can be inspected.
    match unwind::halt_unwinding(|| join(|| panic!("Hello, world!"), || x = true)) {
        Ok(_) => panic!("failed to propagate panic from closure A"),
        Err(_) => assert!(x, "closure B failed to execute"),
    }
}
87 
/// Called from outside any pool, both closures are injected, so each one
/// should report itself as migrated.
#[test]
fn join_context_both() {
    let migrated = join_context(|ctx| ctx.migrated(), |ctx| ctx.migrated());
    assert!(migrated.0);
    assert!(migrated.1);
}
95 
/// Inside a single-threaded pool there is no other worker to steal either
/// closure, so neither should report migration.
#[test]
fn join_context_neither() {
    let pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
    let migrated = pool.install(|| join_context(|ctx| ctx.migrated(), |ctx| ctx.migrated()));
    assert!(!migrated.0);
    assert!(!migrated.1);
}
105 
/// In a two-thread pool, the first closure stays on the calling worker and the
/// second should be stolen. Both closures wait on a two-party barrier, so
/// neither can finish until the other is running at the same time.
#[test]
fn join_context_second() {
    use std::sync::Barrier;

    let pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();
    let rendezvous = Barrier::new(2);
    let (first_migrated, second_migrated) = pool.install(|| {
        join_context(
            |ctx| {
                rendezvous.wait();
                ctx.migrated()
            },
            |ctx| {
                rendezvous.wait();
                ctx.migrated()
            },
        )
    });
    assert!(!first_migrated);
    assert!(second_migrated);
}
128 
/// Regression test: calling `join` many times used to hit overflow
/// debug-assertions in JEC on 32-bit targets
/// (https://github.com/rayon-rs/rayon/issues/797).
#[test]
fn join_counter_overflow() {
    const MAX: u32 = 500_000;

    let pool = ThreadPoolBuilder::new().num_threads(2).build().unwrap();
    let (mut left, mut right) = (0, 0);

    // Each round bumps both counters exactly once.
    for _round in 0..MAX {
        pool.join(|| left += 1, || right += 1);
    }

    assert_eq!(left, MAX);
    assert_eq!(right, MAX);
}
146