1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::{
6 cmp::Reverse,
7 collections::{BTreeMap, VecDeque},
8 future::{pending, Future},
9 num::Wrapping,
10 sync::Arc,
11 task::{self, Poll, Waker},
12 thread::{self, ThreadId},
13 time::{Duration, Instant},
14 };
15
16 use anyhow::Result;
17 use async_task::{Runnable, Task};
18 use futures::{pin_mut, task::WakerRef};
19 use once_cell::unsync::Lazy;
20 use smallvec::SmallVec;
21 use sync::Mutex;
22
23 use crate::{enter::enter, sys, BlockingPool};
24
25 thread_local! (static LOCAL_CONTEXT: Lazy<Arc<Mutex<Context>>> = Lazy::new(Default::default));
26
27 #[derive(Default)]
28 struct Context {
29 queue: VecDeque<Runnable>,
30 timers: BTreeMap<Reverse<Instant>, SmallVec<[Waker; 2]>>,
31 waker: Option<Waker>,
32 }
33
34 #[derive(Default)]
35 struct Shared {
36 queue: VecDeque<Runnable>,
37 idle_workers: VecDeque<(ThreadId, Waker)>,
38 blocking_pool: BlockingPool,
39 }
40
add_timer(deadline: Instant, waker: &Waker)41 pub(crate) fn add_timer(deadline: Instant, waker: &Waker) {
42 LOCAL_CONTEXT.with(|local_ctx| {
43 let mut ctx = local_ctx.lock();
44 let wakers = ctx.timers.entry(Reverse(deadline)).or_default();
45 if wakers.iter().all(|w| !w.will_wake(waker)) {
46 wakers.push(waker.clone());
47 }
48 });
49 }
50
51 /// An executor for scheduling tasks that poll futures to completion.
52 ///
53 /// All asynchronous operations must run within an executor, which is capable of spawning futures as
54 /// tasks. This executor also provides a mechanism for performing asynchronous I/O operations.
55 ///
56 /// The returned type is a cheap, clonable handle to the underlying executor. Cloning it will only
57 /// create a new reference, not a new executor.
58 ///
59 /// # Examples
60 ///
61 /// Concurrently wait for multiple files to become readable/writable and then read/write the data.
62 ///
63 /// ```
64 /// use std::{
65 /// cmp::min,
66 /// convert::TryFrom,
67 /// fs::OpenOptions,
68 /// };
69 ///
70 /// use anyhow::Result;
71 /// use cros_async::{Executor, File};
72 /// use futures::future::join3;
73 ///
74 /// const CHUNK_SIZE: usize = 32;
75 ///
76 /// // Transfer `len` bytes of data from `from` to `to`.
77 /// async fn transfer_data(from: File, to: File, len: usize) -> Result<usize> {
78 /// let mut rem = len;
79 /// let mut buf = [0u8; CHUNK_SIZE];
80 /// while rem > 0 {
81 /// let count = from.read(&mut buf, None).await?;
82 ///
83 /// if count == 0 {
84 /// // End of file. Return the number of bytes transferred.
85 /// return Ok(len - rem);
86 /// }
87 ///
88 /// to.write_all(&buf[..count], None).await?;
89 ///
90 /// rem = rem.saturating_sub(count);
91 /// }
92 ///
93 /// Ok(len)
94 /// }
95 ///
96 /// # fn do_it() -> Result<()> {
97 /// let (rx, tx) = sys_util::pipe(true)?;
98 /// let zero = File::open("/dev/zero")?;
99 /// let zero_bytes = CHUNK_SIZE * 7;
100 /// let zero_to_pipe = transfer_data(
101 /// zero,
102 /// File::try_from(tx.try_clone()?)?,
103 /// zero_bytes,
104 /// );
105 ///
106 /// let rand = File::open("/dev/urandom")?;
107 /// let rand_bytes = CHUNK_SIZE * 19;
108 /// let rand_to_pipe = transfer_data(
109 /// rand,
110 /// File::try_from(tx)?,
111 /// rand_bytes
112 /// );
113 ///
114 /// let null = OpenOptions::new().write(true).open("/dev/null")?;
115 /// let null_bytes = zero_bytes + rand_bytes;
116 /// let pipe_to_null = transfer_data(
117 /// File::try_from(rx)?,
118 /// File::try_from(null)?,
119 /// null_bytes
120 /// );
121 ///
122 /// Executor::new().run_until(join3(
123 /// async { assert_eq!(pipe_to_null.await.unwrap(), null_bytes) },
124 /// async { assert_eq!(zero_to_pipe.await.unwrap(), zero_bytes) },
125 /// async { assert_eq!(rand_to_pipe.await.unwrap(), rand_bytes) },
126 /// ))?;
127 ///
128 /// # Ok(())
129 /// # }
130 ///
131 /// # do_it().unwrap();
132 /// ```
133 #[derive(Clone, Default)]
134 pub struct Executor {
135 shared: Arc<Mutex<Shared>>,
136 }
137
138 impl Executor {
139 /// Create a new `Executor`.
new() -> Executor140 pub fn new() -> Executor {
141 Default::default()
142 }
143
144 /// Spawn a new future for this executor to run to completion. Callers may use the returned
145 /// `Task` to await on the result of `f`. Dropping the returned `Task` will cancel `f`,
146 /// preventing it from being polled again. To drop a `Task` without canceling the future
147 /// associated with it use [`Task::detach`]. To cancel a task gracefully and wait until it is
148 /// fully destroyed, use [`Task::cancel`].
149 ///
150 /// # Examples
151 ///
152 /// ```
153 /// # use anyhow::Result;
154 /// # fn example_spawn() -> Result<()> {
155 /// # use std::thread;
156 /// #
157 /// # use cros_async::Executor;
158 /// #
159 /// # let ex = Executor::new();
160 /// #
161 /// # // Spawn a thread that runs the executor.
162 /// # let ex2 = ex.clone();
163 /// # thread::spawn(move || ex2.run());
164 /// #
165 /// let task = ex.spawn(async { 7 + 13 });
166 ///
167 /// let result = ex.run_until(task)?;
168 /// assert_eq!(result, 20);
169 /// # Ok(())
170 /// # }
171 /// #
172 /// # example_spawn().unwrap();
173 /// ```
spawn<F>(&self, f: F) -> Task<F::Output> where F: Future + Send + 'static, F::Output: Send + 'static,174 pub fn spawn<F>(&self, f: F) -> Task<F::Output>
175 where
176 F: Future + Send + 'static,
177 F::Output: Send + 'static,
178 {
179 let weak_shared = Arc::downgrade(&self.shared);
180 let schedule = move |runnable| {
181 if let Some(shared) = weak_shared.upgrade() {
182 let waker = {
183 let mut s = shared.lock();
184 s.queue.push_back(runnable);
185 s.idle_workers.pop_front()
186 };
187
188 if let Some((_, w)) = waker {
189 w.wake();
190 }
191 }
192 };
193 let (runnable, task) = async_task::spawn(f, schedule);
194 runnable.schedule();
195 task
196 }
197
198 /// Spawn a thread-local task for this executor to drive to completion. Like `spawn` but without
199 /// requiring `Send` on `F` or `F::Output`. This method should only be called from the same
200 /// thread where `run()` or `run_until()` is called.
201 ///
202 /// # Panics
203 ///
204 /// `Executor::run` and `Executor::run_util` will panic if they try to poll a future that was
205 /// added by calling `spawn_local` from a different thread.
206 ///
207 /// # Examples
208 ///
209 /// ```
210 /// # use anyhow::Result;
211 /// # fn example_spawn_local() -> Result<()> {
212 /// # use cros_async::Executor;
213 /// #
214 /// # let ex = Executor::new();
215 /// #
216 /// let task = ex.spawn_local(async { 7 + 13 });
217 ///
218 /// let result = ex.run_until(task)?;
219 /// assert_eq!(result, 20);
220 /// # Ok(())
221 /// # }
222 /// #
223 /// # example_spawn_local().unwrap();
224 /// ```
spawn_local<F>(&self, f: F) -> Task<F::Output> where F: Future + 'static, F::Output: 'static,225 pub fn spawn_local<F>(&self, f: F) -> Task<F::Output>
226 where
227 F: Future + 'static,
228 F::Output: 'static,
229 {
230 let weak_ctx = LOCAL_CONTEXT.with(|ctx| Arc::downgrade(ctx));
231 let schedule = move |runnable| {
232 if let Some(local_ctx) = weak_ctx.upgrade() {
233 let waker = {
234 let mut ctx = local_ctx.lock();
235 ctx.queue.push_back(runnable);
236 ctx.waker.take()
237 };
238
239 if let Some(w) = waker {
240 w.wake();
241 }
242 }
243 };
244 let (runnable, task) = async_task::spawn_local(f, schedule);
245 runnable.schedule();
246 task
247 }
248
249 /// Run the provided closure on a dedicated thread where blocking is allowed.
250 ///
251 /// Callers may `await` on the returned `Task` to wait for the result of `f`. Dropping or
252 /// canceling the returned `Task` may not cancel the operation if it was already started on a
253 /// worker thread.
254 ///
255 /// # Panics
256 ///
257 /// `await`ing the `Task` after the `Executor` is dropped will panic if the work was not already
258 /// completed.
259 ///
260 /// # Examples
261 ///
262 /// ```edition2018
263 /// # use cros_async::Executor;
264 /// #
265 /// # async fn do_it(ex: &Executor) {
266 /// let res = ex.spawn_blocking(move || {
267 /// // Do some CPU-intensive or blocking work here.
268 ///
269 /// 42
270 /// }).await;
271 ///
272 /// assert_eq!(res, 42);
273 /// # }
274 /// #
275 /// # let ex = Executor::new();
276 /// # ex.run_until(do_it(&ex)).unwrap();
277 /// ```
spawn_blocking<F, R>(&self, f: F) -> Task<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,278 pub fn spawn_blocking<F, R>(&self, f: F) -> Task<R>
279 where
280 F: FnOnce() -> R + Send + 'static,
281 R: Send + 'static,
282 {
283 self.shared.lock().blocking_pool.spawn(f)
284 }
285
286 /// Run the executor indefinitely, driving all spawned futures to completion. This method will
287 /// block the current thread and only return in the case of an error.
288 ///
289 /// # Examples
290 ///
291 /// ```
292 /// # use anyhow::Result;
293 /// # fn example_run() -> Result<()> {
294 /// use std::thread;
295 ///
296 /// use cros_async::Executor;
297 ///
298 /// let ex = Executor::new();
299 ///
300 /// // Spawn a thread that runs the executor.
301 /// let ex2 = ex.clone();
302 /// thread::spawn(move || ex2.run());
303 ///
304 /// let task = ex.spawn(async { 7 + 13 });
305 ///
306 /// let result = ex.run_until(task)?;
307 /// assert_eq!(result, 20);
308 /// # Ok(())
309 /// # }
310 /// #
311 /// # example_run().unwrap();
312 /// ```
313 #[inline]
run(&self) -> Result<()>314 pub fn run(&self) -> Result<()> {
315 self.run_until(pending())
316 }
317
318 /// Drive all futures spawned in this executor until `f` completes. This method will block the
319 /// current thread only until `f` is complete and there may still be unfinished futures in the
320 /// executor.
321 ///
322 /// # Examples
323 ///
324 /// ```
325 /// # use anyhow::Result;
326 /// # fn example_run_until() -> Result<()> {
327 /// use cros_async::Executor;
328 ///
329 /// let ex = Executor::new();
330 ///
331 /// let task = ex.spawn_local(async { 7 + 13 });
332 ///
333 /// let result = ex.run_until(task)?;
334 /// assert_eq!(result, 20);
335 /// # Ok(())
336 /// # }
337 /// #
338 /// # example_run_until().unwrap();
339 /// ```
run_until<F: Future>(&self, done: F) -> Result<F::Output>340 pub fn run_until<F: Future>(&self, done: F) -> Result<F::Output> {
341 // Prevent nested execution.
342 let _guard = enter()?;
343
344 pin_mut!(done);
345
346 let current_thread = thread::current().id();
347 let state = sys::platform_state()?;
348 let waker = state.waker_ref();
349 let mut cx = task::Context::from_waker(&waker);
350 let mut done_polled = false;
351
352 LOCAL_CONTEXT.with(|local_ctx| {
353 let next_local = || local_ctx.lock().queue.pop_front();
354 let next_global = || self.shared.lock().queue.pop_front();
355
356 let mut tick = Wrapping(0u32);
357
358 loop {
359 tick += Wrapping(1);
360
361 // If there are always tasks available to run in either the local or the global
362 // queue then we may go a long time without fetching completed events from the
363 // underlying platform driver. Poll the driver once in a while to prevent this from
364 // happening.
365 if tick.0 % 31 == 0 {
366 // A zero timeout will fetch new events without blocking.
367 self.get_events(&state, Some(Duration::from_millis(0)))?;
368 }
369
370 let was_woken = state.start_processing();
371 if was_woken || !done_polled {
372 done_polled = true;
373 if let Poll::Ready(v) = done.as_mut().poll(&mut cx) {
374 return Ok(v);
375 }
376 }
377
378 // If there are always tasks in the local queue then any tasks in the global queue
379 // will get starved. Pull tasks out of the global queue every once in a while even
380 // when there are still local tasks available to prevent this.
381 let next_runnable = if tick.0 % 13 == 0 {
382 next_global().or_else(next_local)
383 } else {
384 next_local().or_else(next_global)
385 };
386
387 if let Some(runnable) = next_runnable {
388 runnable.run();
389 continue;
390 }
391
392 // We're about to block so first check that new tasks have not snuck in and set the
393 // waker so that we can be woken up when tasks are re-scheduled.
394 let deadline = {
395 let mut ctx = local_ctx.lock();
396 if !ctx.queue.is_empty() {
397 // Some more tasks managed to sneak in. Go back to the start of the loop.
398 continue;
399 }
400
401 // There are no more tasks to run so set the waker.
402 if ctx.waker.is_none() {
403 ctx.waker = Some(cx.waker().clone());
404 }
405
406 // TODO: Replace with `last_entry` once it is stabilized.
407 ctx.timers.keys().next_back().cloned()
408 };
409 {
410 let mut shared = self.shared.lock();
411 if !shared.queue.is_empty() {
412 // More tasks were added to the global queue. Go back to the start of the loop.
413 continue;
414 }
415
416 // We're going to block so add ourselves to the idle worker list.
417 shared
418 .idle_workers
419 .push_back((current_thread, cx.waker().clone()));
420 };
421
422 // Now wait to be woken up.
423 let timeout = deadline.map(|d| d.0.saturating_duration_since(Instant::now()));
424 self.get_events(&state, timeout)?;
425
426 // Remove from idle workers.
427 {
428 let mut shared = self.shared.lock();
429 if let Some(idx) = shared
430 .idle_workers
431 .iter()
432 .position(|(id, _)| id == ¤t_thread)
433 {
434 shared.idle_workers.swap_remove_back(idx);
435 }
436 }
437
438 // Reset the ticks since we just fetched new events from the platform driver.
439 tick = Wrapping(0);
440 }
441 })
442 }
443
get_events<S: PlatformState>( &self, state: &S, timeout: Option<Duration>, ) -> anyhow::Result<()>444 fn get_events<S: PlatformState>(
445 &self,
446 state: &S,
447 timeout: Option<Duration>,
448 ) -> anyhow::Result<()> {
449 state.wait(timeout)?;
450
451 // Timer maintenance.
452 let expired = LOCAL_CONTEXT.with(|local_ctx| {
453 let mut ctx = local_ctx.lock();
454 let now = Instant::now();
455 ctx.timers.split_off(&Reverse(now))
456 });
457
458 // We cannot wake the timers while holding the lock because the schedule function for the
459 // task that's waiting on the timer may try to acquire the lock.
460 for (deadline, wakers) in expired {
461 debug_assert!(deadline.0 <= Instant::now());
462 for w in wakers {
463 w.wake();
464 }
465 }
466
467 Ok(())
468 }
469 }
470
471 // A trait that represents any thread-local platform-specific state that needs to be held on behalf
472 // of the `Executor`.
473 pub(crate) trait PlatformState {
474 // Indicates that the `Executor` is about to start processing futures that have been woken up.
475 //
476 // Implementations may use this as an indicator to skip unnecessary work when new tasks are
477 // woken up as the `Executor` will eventually get around to processing them on its own.
478 //
479 // `start_processing` must return true if one or more futures were woken up since the last call
480 // to `start_processing`. Otherwise it may return false.
start_processing(&self) -> bool481 fn start_processing(&self) -> bool;
482
483 // Returns a `WakerRef` that can be used to wake up the current thread.
waker_ref(&self) -> WakerRef484 fn waker_ref(&self) -> WakerRef;
485
486 // Waits for one or more futures to be woken up.
487 //
488 // This method should check with the underlying OS if any asynchronous IO operations have
489 // completed and then wake up the associated futures.
490 //
491 // If `timeout` is provided then this method should block until either one or more futures are
492 // woken up or the timeout duration elapses. If `timeout` has a zero duration then this method
493 // should fetch completed asynchronous IO operations and then immediately return.
494 //
495 // If `timeout` is not provided then this method should block until one or more futures are
496 // woken up.
wait(&self, timeout: Option<Duration>) -> anyhow::Result<()>497 fn wait(&self, timeout: Option<Duration>) -> anyhow::Result<()>;
498 }
499
#[cfg(test)]
mod test {
    use super::*;

    use std::{
        convert::TryFrom,
        fs::OpenOptions,
        mem,
        pin::Pin,
        thread::{self, JoinHandle},
        time::Instant,
    };

    use futures::{
        channel::{mpsc, oneshot},
        future::{join3, select, Either},
        sink::SinkExt,
        stream::{self, FuturesUnordered, StreamExt},
    };

    use crate::{File, OwnedIoBuf};

    #[test]
    fn basic() {
        async fn do_it() {
            let (r, _w) = sys_util::pipe(true).unwrap();
            let done = async { 5usize };

            let rx = File::try_from(r).unwrap();
            let mut buf = 0u64.to_ne_bytes();
            let pending = rx.read(&mut buf, None);
            pin_mut!(pending, done);

            match select(pending, done).await {
                Either::Right((5, pending)) => drop(pending),
                _ => panic!("unexpected select result"),
            }
        }

        Executor::new().run_until(do_it()).unwrap();
    }

    #[derive(Default)]
    struct QuitShared {
        wakers: Vec<task::Waker>,
        should_quit: bool,
    }

    /// A future that stays pending until `quit` is called on any clone of it.
    #[derive(Clone, Default)]
    struct Quit {
        shared: Arc<Mutex<QuitShared>>,
    }

    impl Quit {
        fn quit(self) {
            let wakers = {
                let mut shared = self.shared.lock();
                shared.should_quit = true;
                mem::take(&mut shared.wakers)
            };

            // Wake outside the lock.
            for w in wakers {
                w.wake();
            }
        }
    }

    impl Future for Quit {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            let mut shared = self.shared.lock();
            if shared.should_quit {
                return Poll::Ready(());
            }

            if shared.wakers.iter().all(|w| !w.will_wake(cx.waker())) {
                shared.wakers.push(cx.waker().clone());
            }

            Poll::Pending
        }
    }

    #[test]
    fn outer_future_is_send() {
        const NUM_THREADS: usize = 3;
        const CHUNK_SIZE: usize = 32;

        async fn read_iobuf(
            ex: &Executor,
            f: File,
            buf: OwnedIoBuf,
        ) -> (anyhow::Result<usize>, OwnedIoBuf, File) {
            let (tx, rx) = oneshot::channel();
            ex.spawn_local(async move {
                let (res, buf) = f.read_iobuf(buf, None).await;
                let _ = tx.send((res, buf, f));
            })
            .detach();
            rx.await.unwrap()
        }

        async fn write_iobuf(
            ex: &Executor,
            f: File,
            buf: OwnedIoBuf,
        ) -> (anyhow::Result<usize>, OwnedIoBuf, File) {
            let (tx, rx) = oneshot::channel();
            ex.spawn_local(async move {
                let (res, buf) = f.write_iobuf(buf, None).await;
                let _ = tx.send((res, buf, f));
            })
            .detach();
            rx.await.unwrap()
        }

        async fn transfer_data(
            ex: Executor,
            mut from: File,
            mut to: File,
            len: usize,
        ) -> Result<usize> {
            let mut rem = len;
            let mut buf = OwnedIoBuf::new(vec![0xa2u8; CHUNK_SIZE]);
            while rem > 0 {
                let (res, data, f) = read_iobuf(&ex, from, buf).await;
                let count = res?;
                buf = data;
                from = f;
                if count == 0 {
                    // End of file. Return the number of bytes transferred.
                    return Ok(len - rem);
                }
                assert_eq!(count, CHUNK_SIZE);

                let (res, data, t) = write_iobuf(&ex, to, buf).await;
                let count = res?;
                buf = data;
                to = t;
                assert_eq!(count, CHUNK_SIZE);

                rem = rem.saturating_sub(count);
            }

            Ok(len)
        }

        fn do_it() -> anyhow::Result<()> {
            let ex = Executor::new();
            let (rx, tx) = sys_util::pipe(true)?;
            let zero = File::open("/dev/zero")?;
            let zero_bytes = CHUNK_SIZE * 7;
            let zero_to_pipe = ex.spawn(transfer_data(
                ex.clone(),
                zero,
                File::try_from(tx.try_clone()?)?,
                zero_bytes,
            ));

            let rand = File::open("/dev/urandom")?;
            let rand_bytes = CHUNK_SIZE * 19;
            let rand_to_pipe = ex.spawn(transfer_data(
                ex.clone(),
                rand,
                File::try_from(tx)?,
                rand_bytes,
            ));

            let null = OpenOptions::new().write(true).open("/dev/null")?;
            let null_bytes = zero_bytes + rand_bytes;
            let pipe_to_null = ex.spawn(transfer_data(
                ex.clone(),
                File::try_from(rx)?,
                File::try_from(null)?,
                null_bytes,
            ));

            let mut threads = Vec::with_capacity(NUM_THREADS);
            let quit = Quit::default();
            for _ in 0..NUM_THREADS {
                let thread_ex = ex.clone();
                let thread_quit = quit.clone();
                threads.push(thread::spawn(move || thread_ex.run_until(thread_quit)))
            }
            ex.run_until(join3(
                async { assert_eq!(pipe_to_null.await.unwrap(), null_bytes) },
                async { assert_eq!(zero_to_pipe.await.unwrap(), zero_bytes) },
                async { assert_eq!(rand_to_pipe.await.unwrap(), rand_bytes) },
            ))?;

            quit.quit();
            for t in threads {
                t.join().unwrap().unwrap();
            }

            Ok(())
        }

        do_it().unwrap();
    }

    #[test]
    fn thread_pool() {
        const NUM_THREADS: usize = 8;
        const NUM_CHANNELS: usize = 19;
        const NUM_ITERATIONS: usize = 71;

        let ex = Executor::new();

        let tasks = FuturesUnordered::new();
        let (mut tx, mut rx) = mpsc::channel(10);
        tasks.push(ex.spawn(async move {
            for i in 0..NUM_ITERATIONS {
                tx.send(i).await?;
            }

            Ok::<(), anyhow::Error>(())
        }));

        // Build a chain of forwarding tasks, each relaying values to the next channel.
        for _ in 0..NUM_CHANNELS {
            let (mut task_tx, task_rx) = mpsc::channel(10);
            tasks.push(ex.spawn(async move {
                while let Some(v) = rx.next().await {
                    task_tx.send(v).await?;
                }

                Ok::<(), anyhow::Error>(())
            }));

            rx = task_rx;
        }

        tasks.push(ex.spawn(async move {
            let mut zip = rx.zip(stream::iter(0..NUM_ITERATIONS));
            while let Some((l, r)) = zip.next().await {
                assert_eq!(l, r);
            }

            Ok::<(), anyhow::Error>(())
        }));

        let quit = Quit::default();
        let mut threads = Vec::with_capacity(NUM_THREADS);
        for _ in 0..NUM_THREADS {
            let thread_ex = ex.clone();
            let thread_quit = quit.clone();
            threads.push(thread::spawn(move || thread_ex.run_until(thread_quit)));
        }

        let results = ex
            .run_until(tasks.collect::<Vec<anyhow::Result<()>>>())
            .unwrap();
        results
            .into_iter()
            .collect::<anyhow::Result<Vec<()>>>()
            .unwrap();

        quit.quit();
        for t in threads {
            t.join().unwrap().unwrap();
        }
    }

    // Sends a message on `tx` once there is an idle worker in `Executor` or 5 seconds have passed.
    // Sends true if this function observed an idle worker and false otherwise.
    fn notify_on_idle_worker(ex: Executor, tx: oneshot::Sender<bool>) {
        let deadline = Instant::now() + Duration::from_secs(5);
        let mut observed = false;
        while Instant::now() < deadline {
            // Wait for the main thread to add itself to the idle worker list.
            if !ex.shared.lock().idle_workers.is_empty() {
                observed = true;
                break;
            }

            thread::sleep(Duration::from_millis(10));
        }

        // Report what was actually observed. Re-reading the clock here could misreport an idle
        // worker seen just before the deadline as a timeout.
        tx.send(observed).unwrap();
    }

    #[test]
    fn wakeup_run_until() {
        let (tx, rx) = oneshot::channel();

        let ex = Executor::new();

        let thread_ex = ex.clone();
        let waker_thread = thread::spawn(move || notify_on_idle_worker(thread_ex, tx));

        // Since we're using `run_until` the wakeup path won't use the regular scheduling functions.
        let success = ex.run_until(rx).unwrap().unwrap();
        assert!(success);
        assert!(ex.shared.lock().idle_workers.is_empty());

        waker_thread.join().unwrap();
    }

    #[test]
    fn wakeup_local_task() {
        let (tx, rx) = oneshot::channel();

        let ex = Executor::new();

        let thread_ex = ex.clone();
        let waker_thread = thread::spawn(move || notify_on_idle_worker(thread_ex, tx));

        // By using `spawn_local`, the wakeup path will go via LOCAL_CTX.
        let task = ex.spawn_local(rx);
        let success = ex.run_until(task).unwrap().unwrap();
        assert!(success);
        assert!(ex.shared.lock().idle_workers.is_empty());

        waker_thread.join().unwrap();
    }

    #[test]
    fn wakeup_global_task() {
        let (tx, rx) = oneshot::channel();

        let ex = Executor::new();

        let thread_ex = ex.clone();
        let waker_thread = thread::spawn(move || notify_on_idle_worker(thread_ex, tx));

        // By using `spawn`, the wakeup path will go via `ex.shared`.
        let task = ex.spawn(rx);
        let success = ex.run_until(task).unwrap().unwrap();
        assert!(success);
        assert!(ex.shared.lock().idle_workers.is_empty());

        waker_thread.join().unwrap();
    }

    #[test]
    fn wake_up_correct_worker() {
        struct ThreadData {
            id: ThreadId,
            sender: mpsc::Sender<()>,
            handle: JoinHandle<anyhow::Result<()>>,
        }

        const NUM_THREADS: usize = 7;
        const NUM_ITERATIONS: usize = 119;

        let ex = Executor::new();

        let (tx, mut rx) = mpsc::channel(0);
        let mut threads = Vec::with_capacity(NUM_THREADS);
        for _ in 0..NUM_THREADS {
            let (sender, mut receiver) = mpsc::channel(0);
            let mut thread_tx = tx.clone();
            let thread_ex = ex.clone();
            let handle = thread::spawn(move || {
                let id = thread::current().id();
                thread_ex
                    .run_until(async move {
                        // Echo this thread's id back for every message received.
                        while let Some(()) = receiver.next().await {
                            thread_tx.send(id).await?;
                        }

                        Ok(())
                    })
                    .unwrap()
            });

            let id = handle.thread().id();
            threads.push(ThreadData { id, sender, handle });
        }

        ex.run_until(async {
            for i in 0..NUM_ITERATIONS {
                let data = &mut threads[i % NUM_THREADS];
                data.sender.send(()).await?;
                assert_eq!(rx.next().await.unwrap(), data.id);
            }

            Ok::<(), anyhow::Error>(())
        })
        .unwrap()
        .unwrap();

        for t in threads {
            let ThreadData { id, sender, handle } = t;

            // Dropping the sender will close the channel and cause the thread to exit.
            drop((id, sender));
            handle.join().unwrap().unwrap();
        }
    }
}
894