• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #![warn(rust_2018_idioms)]
2 #![cfg(not(target_os = "wasi"))] // Wasi doesn't support threads
3 
4 use std::rc::Rc;
5 use std::sync::Arc;
6 use tokio::sync::Barrier;
7 use tokio_util::task;
8 
9 /// Simple test of running a !Send future via spawn_pinned
10 #[tokio::test]
can_spawn_not_send_future()11 async fn can_spawn_not_send_future() {
12     let pool = task::LocalPoolHandle::new(1);
13 
14     let output = pool
15         .spawn_pinned(|| {
16             // Rc is !Send + !Sync
17             let local_data = Rc::new("test");
18 
19             // This future holds an Rc, so it is !Send
20             async move { local_data.to_string() }
21         })
22         .await
23         .unwrap();
24 
25     assert_eq!(output, "test");
26 }
27 
/// Discarding the join handle right away must not stop the task: its output
/// is still observed through a side channel.
#[test]
fn can_drop_future_and_still_get_output() {
    let pool = task::LocalPoolHandle::new(1);
    let (tx, rx) = std::sync::mpsc::channel();

    // The join handle is dropped immediately; only the channel reports back.
    drop(pool.spawn_pinned(move || {
        // `Rc` is neither `Send` nor `Sync`, so the future below is `!Send`.
        let shared = Rc::new("test");

        async move {
            let _ = tx.send(shared.to_string());
        }
    }));

    let received = rx.recv();
    assert_eq!(received, Ok("test".to_string()));
}
46 
/// Requesting a pool with zero workers is expected to panic with the
/// `pool_size > 0` assertion message.
#[test]
#[should_panic(expected = "assertion failed: pool_size > 0")]
fn cannot_create_zero_sized_pool() {
    let pool_size = 0;
    let _pool = task::LocalPoolHandle::new(pool_size);
}
52 
53 /// We should be able to spawn multiple futures onto the pool at the same time.
54 #[tokio::test]
can_spawn_multiple_futures()55 async fn can_spawn_multiple_futures() {
56     let pool = task::LocalPoolHandle::new(2);
57 
58     let join_handle1 = pool.spawn_pinned(|| {
59         let local_data = Rc::new("test1");
60         async move { local_data.to_string() }
61     });
62     let join_handle2 = pool.spawn_pinned(|| {
63         let local_data = Rc::new("test2");
64         async move { local_data.to_string() }
65     });
66 
67     assert_eq!(join_handle1.await.unwrap(), "test1");
68     assert_eq!(join_handle2.await.unwrap(), "test2");
69 }
70 
71 /// A panic in the spawned task causes the join handle to return an error.
72 /// But, you can continue to spawn tasks.
73 #[tokio::test]
task_panic_propagates()74 async fn task_panic_propagates() {
75     let pool = task::LocalPoolHandle::new(1);
76 
77     let join_handle = pool.spawn_pinned(|| async {
78         panic!("Test panic");
79     });
80 
81     let result = join_handle.await;
82     assert!(result.is_err());
83     let error = result.unwrap_err();
84     assert!(error.is_panic());
85     let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
86     assert_eq!(*panic_str, "Test panic");
87 
88     // Trying again with a "safe" task still works
89     let join_handle = pool.spawn_pinned(|| async { "test" });
90     let result = join_handle.await;
91     assert!(result.is_ok());
92     assert_eq!(result.unwrap(), "test");
93 }
94 
95 /// A panic during task creation causes the join handle to return an error.
96 /// But, you can continue to spawn tasks.
97 #[tokio::test]
callback_panic_does_not_kill_worker()98 async fn callback_panic_does_not_kill_worker() {
99     let pool = task::LocalPoolHandle::new(1);
100 
101     let join_handle = pool.spawn_pinned(|| {
102         panic!("Test panic");
103         #[allow(unreachable_code)]
104         async {}
105     });
106 
107     let result = join_handle.await;
108     assert!(result.is_err());
109     let error = result.unwrap_err();
110     assert!(error.is_panic());
111     let panic_str = error.into_panic().downcast::<&'static str>().unwrap();
112     assert_eq!(*panic_str, "Test panic");
113 
114     // Trying again with a "safe" callback works
115     let join_handle = pool.spawn_pinned(|| async { "test" });
116     let result = join_handle.await;
117     assert!(result.is_ok());
118     assert_eq!(result.unwrap(), "test");
119 }
120 
121 /// Canceling the task via the returned join handle cancels the spawned task
122 /// (which has a different, internal join handle).
123 #[tokio::test]
task_cancellation_propagates()124 async fn task_cancellation_propagates() {
125     let pool = task::LocalPoolHandle::new(1);
126     let notify_dropped = Arc::new(());
127     let weak_notify_dropped = Arc::downgrade(&notify_dropped);
128 
129     let (start_sender, start_receiver) = tokio::sync::oneshot::channel();
130     let (drop_sender, drop_receiver) = tokio::sync::oneshot::channel::<()>();
131     let join_handle = pool.spawn_pinned(|| async move {
132         let _drop_sender = drop_sender;
133         // Move the Arc into the task
134         let _notify_dropped = notify_dropped;
135         let _ = start_sender.send(());
136 
137         // Keep the task running until it gets aborted
138         futures::future::pending::<()>().await;
139     });
140 
141     // Wait for the task to start
142     let _ = start_receiver.await;
143 
144     join_handle.abort();
145 
146     // Wait for the inner task to abort, dropping the sender.
147     // The top level join handle aborts quicker than the inner task (the abort
148     // needs to propagate and get processed on the worker thread), so we can't
149     // just await the top level join handle.
150     let _ = drop_receiver.await;
151 
152     // Check that the Arc has been dropped. This verifies that the inner task
153     // was canceled as well.
154     assert!(weak_notify_dropped.upgrade().is_none());
155 }
156 
157 /// Tasks should be given to the least burdened worker. When spawning two tasks
158 /// on a pool with two empty workers the tasks should be spawned on separate
159 /// workers.
160 #[tokio::test]
tasks_are_balanced()161 async fn tasks_are_balanced() {
162     let pool = task::LocalPoolHandle::new(2);
163 
164     // Spawn a task so one thread has a task count of 1
165     let (start_sender1, start_receiver1) = tokio::sync::oneshot::channel();
166     let (end_sender1, end_receiver1) = tokio::sync::oneshot::channel();
167     let join_handle1 = pool.spawn_pinned(|| async move {
168         let _ = start_sender1.send(());
169         let _ = end_receiver1.await;
170         std::thread::current().id()
171     });
172 
173     // Wait for the first task to start up
174     let _ = start_receiver1.await;
175 
176     // This task should be spawned on the other thread
177     let (start_sender2, start_receiver2) = tokio::sync::oneshot::channel();
178     let join_handle2 = pool.spawn_pinned(|| async move {
179         let _ = start_sender2.send(());
180         std::thread::current().id()
181     });
182 
183     // Wait for the second task to start up
184     let _ = start_receiver2.await;
185 
186     // Allow the first task to end
187     let _ = end_sender1.send(());
188 
189     let thread_id1 = join_handle1.await.unwrap();
190     let thread_id2 = join_handle2.await.unwrap();
191 
192     // Since the first task was active when the second task spawned, they should
193     // be on separate workers/threads.
194     assert_ne!(thread_id1, thread_id2);
195 }
196 
197 #[tokio::test]
spawn_by_idx()198 async fn spawn_by_idx() {
199     let pool = task::LocalPoolHandle::new(3);
200     let barrier = Arc::new(Barrier::new(4));
201     let barrier1 = barrier.clone();
202     let barrier2 = barrier.clone();
203     let barrier3 = barrier.clone();
204 
205     let handle1 = pool.spawn_pinned_by_idx(
206         || async move {
207             barrier1.wait().await;
208             std::thread::current().id()
209         },
210         0,
211     );
212     let _ = pool.spawn_pinned_by_idx(
213         || async move {
214             barrier2.wait().await;
215             std::thread::current().id()
216         },
217         0,
218     );
219     let handle2 = pool.spawn_pinned_by_idx(
220         || async move {
221             barrier3.wait().await;
222             std::thread::current().id()
223         },
224         1,
225     );
226 
227     let loads = pool.get_task_loads_for_each_worker();
228     barrier.wait().await;
229     assert_eq!(loads[0], 2);
230     assert_eq!(loads[1], 1);
231     assert_eq!(loads[2], 0);
232 
233     let thread_id1 = handle1.await.unwrap();
234     let thread_id2 = handle2.await.unwrap();
235 
236     assert_ne!(thread_id1, thread_id2);
237 }
238