• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2021 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::{
6     cmp::Reverse,
7     collections::{BTreeMap, VecDeque},
8     future::{pending, Future},
9     num::Wrapping,
10     sync::Arc,
11     task::{self, Poll, Waker},
12     thread::{self, ThreadId},
13     time::{Duration, Instant},
14 };
15 
16 use anyhow::Result;
17 use async_task::{Runnable, Task};
18 use futures::{pin_mut, task::WakerRef};
19 use once_cell::unsync::Lazy;
20 use smallvec::SmallVec;
21 use sync::Mutex;
22 
23 use crate::{enter::enter, sys, BlockingPool};
24 
25 thread_local! (static LOCAL_CONTEXT: Lazy<Arc<Mutex<Context>>> = Lazy::new(Default::default));
26 
27 #[derive(Default)]
28 struct Context {
29     queue: VecDeque<Runnable>,
30     timers: BTreeMap<Reverse<Instant>, SmallVec<[Waker; 2]>>,
31     waker: Option<Waker>,
32 }
33 
34 #[derive(Default)]
35 struct Shared {
36     queue: VecDeque<Runnable>,
37     idle_workers: VecDeque<(ThreadId, Waker)>,
38     blocking_pool: BlockingPool,
39 }
40 
add_timer(deadline: Instant, waker: &Waker)41 pub(crate) fn add_timer(deadline: Instant, waker: &Waker) {
42     LOCAL_CONTEXT.with(|local_ctx| {
43         let mut ctx = local_ctx.lock();
44         let wakers = ctx.timers.entry(Reverse(deadline)).or_default();
45         if wakers.iter().all(|w| !w.will_wake(waker)) {
46             wakers.push(waker.clone());
47         }
48     });
49 }
50 
51 /// An executor for scheduling tasks that poll futures to completion.
52 ///
53 /// All asynchronous operations must run within an executor, which is capable of spawning futures as
54 /// tasks. This executor also provides a mechanism for performing asynchronous I/O operations.
55 ///
56 /// The returned type is a cheap, clonable handle to the underlying executor. Cloning it will only
57 /// create a new reference, not a new executor.
58 ///
59 /// # Examples
60 ///
61 /// Concurrently wait for multiple files to become readable/writable and then read/write the data.
62 ///
63 /// ```
64 /// use std::{
65 ///     cmp::min,
66 ///     convert::TryFrom,
67 ///     fs::OpenOptions,
68 /// };
69 ///
70 /// use anyhow::Result;
71 /// use cros_async::{Executor, File};
72 /// use futures::future::join3;
73 ///
74 /// const CHUNK_SIZE: usize = 32;
75 ///
76 /// // Transfer `len` bytes of data from `from` to `to`.
77 /// async fn transfer_data(from: File, to: File, len: usize) -> Result<usize> {
78 ///     let mut rem = len;
79 ///     let mut buf = [0u8; CHUNK_SIZE];
80 ///     while rem > 0 {
81 ///         let count = from.read(&mut buf, None).await?;
82 ///
83 ///         if count == 0 {
84 ///             // End of file. Return the number of bytes transferred.
85 ///             return Ok(len - rem);
86 ///         }
87 ///
88 ///         to.write_all(&buf[..count], None).await?;
89 ///
90 ///         rem = rem.saturating_sub(count);
91 ///     }
92 ///
93 ///     Ok(len)
94 /// }
95 ///
96 /// # fn do_it() -> Result<()> {
97 ///     let (rx, tx) = sys_util::pipe(true)?;
98 ///     let zero = File::open("/dev/zero")?;
99 ///     let zero_bytes = CHUNK_SIZE * 7;
100 ///     let zero_to_pipe = transfer_data(
101 ///         zero,
102 ///         File::try_from(tx.try_clone()?)?,
103 ///         zero_bytes,
104 ///     );
105 ///
106 ///     let rand = File::open("/dev/urandom")?;
107 ///     let rand_bytes = CHUNK_SIZE * 19;
108 ///     let rand_to_pipe = transfer_data(
109 ///         rand,
110 ///         File::try_from(tx)?,
111 ///         rand_bytes
112 ///     );
113 ///
114 ///     let null = OpenOptions::new().write(true).open("/dev/null")?;
115 ///     let null_bytes = zero_bytes + rand_bytes;
116 ///     let pipe_to_null = transfer_data(
117 ///         File::try_from(rx)?,
118 ///         File::try_from(null)?,
119 ///         null_bytes
120 ///     );
121 ///
122 ///     Executor::new().run_until(join3(
123 ///         async { assert_eq!(pipe_to_null.await.unwrap(), null_bytes) },
124 ///         async { assert_eq!(zero_to_pipe.await.unwrap(), zero_bytes) },
125 ///         async { assert_eq!(rand_to_pipe.await.unwrap(), rand_bytes) },
126 ///     ))?;
127 ///
128 /// #     Ok(())
129 /// # }
130 ///
131 /// # do_it().unwrap();
132 /// ```
133 #[derive(Clone, Default)]
134 pub struct Executor {
135     shared: Arc<Mutex<Shared>>,
136 }
137 
138 impl Executor {
139     /// Create a new `Executor`.
new() -> Executor140     pub fn new() -> Executor {
141         Default::default()
142     }
143 
144     /// Spawn a new future for this executor to run to completion. Callers may use the returned
145     /// `Task` to await on the result of `f`. Dropping the returned `Task` will cancel `f`,
146     /// preventing it from being polled again. To drop a `Task` without canceling the future
147     /// associated with it use [`Task::detach`]. To cancel a task gracefully and wait until it is
148     /// fully destroyed, use [`Task::cancel`].
149     ///
150     /// # Examples
151     ///
152     /// ```
153     /// # use anyhow::Result;
154     /// # fn example_spawn() -> Result<()> {
155     /// #      use std::thread;
156     /// #
157     /// #      use cros_async::Executor;
158     /// #
159     /// #      let ex = Executor::new();
160     /// #
161     /// #      // Spawn a thread that runs the executor.
162     /// #      let ex2 = ex.clone();
163     /// #      thread::spawn(move || ex2.run());
164     /// #
165     ///       let task = ex.spawn(async { 7 + 13 });
166     ///
167     ///       let result = ex.run_until(task)?;
168     ///       assert_eq!(result, 20);
169     /// #     Ok(())
170     /// # }
171     /// #
172     /// # example_spawn().unwrap();
173     /// ```
spawn<F>(&self, f: F) -> Task<F::Output> where F: Future + Send + 'static, F::Output: Send + 'static,174     pub fn spawn<F>(&self, f: F) -> Task<F::Output>
175     where
176         F: Future + Send + 'static,
177         F::Output: Send + 'static,
178     {
179         let weak_shared = Arc::downgrade(&self.shared);
180         let schedule = move |runnable| {
181             if let Some(shared) = weak_shared.upgrade() {
182                 let waker = {
183                     let mut s = shared.lock();
184                     s.queue.push_back(runnable);
185                     s.idle_workers.pop_front()
186                 };
187 
188                 if let Some((_, w)) = waker {
189                     w.wake();
190                 }
191             }
192         };
193         let (runnable, task) = async_task::spawn(f, schedule);
194         runnable.schedule();
195         task
196     }
197 
198     /// Spawn a thread-local task for this executor to drive to completion. Like `spawn` but without
199     /// requiring `Send` on `F` or `F::Output`. This method should only be called from the same
200     /// thread where `run()` or `run_until()` is called.
201     ///
202     /// # Panics
203     ///
204     /// `Executor::run` and `Executor::run_util` will panic if they try to poll a future that was
205     /// added by calling `spawn_local` from a different thread.
206     ///
207     /// # Examples
208     ///
209     /// ```
210     /// # use anyhow::Result;
211     /// # fn example_spawn_local() -> Result<()> {
212     /// #      use cros_async::Executor;
213     /// #
214     /// #      let ex = Executor::new();
215     /// #
216     ///       let task = ex.spawn_local(async { 7 + 13 });
217     ///
218     ///       let result = ex.run_until(task)?;
219     ///       assert_eq!(result, 20);
220     /// #     Ok(())
221     /// # }
222     /// #
223     /// # example_spawn_local().unwrap();
224     /// ```
spawn_local<F>(&self, f: F) -> Task<F::Output> where F: Future + 'static, F::Output: 'static,225     pub fn spawn_local<F>(&self, f: F) -> Task<F::Output>
226     where
227         F: Future + 'static,
228         F::Output: 'static,
229     {
230         let weak_ctx = LOCAL_CONTEXT.with(|ctx| Arc::downgrade(ctx));
231         let schedule = move |runnable| {
232             if let Some(local_ctx) = weak_ctx.upgrade() {
233                 let waker = {
234                     let mut ctx = local_ctx.lock();
235                     ctx.queue.push_back(runnable);
236                     ctx.waker.take()
237                 };
238 
239                 if let Some(w) = waker {
240                     w.wake();
241                 }
242             }
243         };
244         let (runnable, task) = async_task::spawn_local(f, schedule);
245         runnable.schedule();
246         task
247     }
248 
249     /// Run the provided closure on a dedicated thread where blocking is allowed.
250     ///
251     /// Callers may `await` on the returned `Task` to wait for the result of `f`. Dropping or
252     /// canceling the returned `Task` may not cancel the operation if it was already started on a
253     /// worker thread.
254     ///
255     /// # Panics
256     ///
257     /// `await`ing the `Task` after the `Executor` is dropped will panic if the work was not already
258     /// completed.
259     ///
260     /// # Examples
261     ///
262     /// ```edition2018
263     /// # use cros_async::Executor;
264     /// #
265     /// # async fn do_it(ex: &Executor) {
266     ///     let res = ex.spawn_blocking(move || {
267     ///         // Do some CPU-intensive or blocking work here.
268     ///
269     ///         42
270     ///     }).await;
271     ///
272     ///     assert_eq!(res, 42);
273     /// # }
274     /// #
275     /// # let ex = Executor::new();
276     /// # ex.run_until(do_it(&ex)).unwrap();
277     /// ```
spawn_blocking<F, R>(&self, f: F) -> Task<R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,278     pub fn spawn_blocking<F, R>(&self, f: F) -> Task<R>
279     where
280         F: FnOnce() -> R + Send + 'static,
281         R: Send + 'static,
282     {
283         self.shared.lock().blocking_pool.spawn(f)
284     }
285 
286     /// Run the executor indefinitely, driving all spawned futures to completion. This method will
287     /// block the current thread and only return in the case of an error.
288     ///
289     /// # Examples
290     ///
291     /// ```
292     /// # use anyhow::Result;
293     /// # fn example_run() -> Result<()> {
294     ///       use std::thread;
295     ///
296     ///       use cros_async::Executor;
297     ///
298     ///       let ex = Executor::new();
299     ///
300     ///       // Spawn a thread that runs the executor.
301     ///       let ex2 = ex.clone();
302     ///       thread::spawn(move || ex2.run());
303     ///
304     ///       let task = ex.spawn(async { 7 + 13 });
305     ///
306     ///       let result = ex.run_until(task)?;
307     ///       assert_eq!(result, 20);
308     /// #     Ok(())
309     /// # }
310     /// #
311     /// # example_run().unwrap();
312     /// ```
313     #[inline]
run(&self) -> Result<()>314     pub fn run(&self) -> Result<()> {
315         self.run_until(pending())
316     }
317 
318     /// Drive all futures spawned in this executor until `f` completes. This method will block the
319     /// current thread only until `f` is complete and there may still be unfinished futures in the
320     /// executor.
321     ///
322     /// # Examples
323     ///
324     /// ```
325     /// # use anyhow::Result;
326     /// # fn example_run_until() -> Result<()> {
327     ///       use cros_async::Executor;
328     ///
329     ///       let ex = Executor::new();
330     ///
331     ///       let task = ex.spawn_local(async { 7 + 13 });
332     ///
333     ///       let result = ex.run_until(task)?;
334     ///       assert_eq!(result, 20);
335     /// #     Ok(())
336     /// # }
337     /// #
338     /// # example_run_until().unwrap();
339     /// ```
run_until<F: Future>(&self, done: F) -> Result<F::Output>340     pub fn run_until<F: Future>(&self, done: F) -> Result<F::Output> {
341         // Prevent nested execution.
342         let _guard = enter()?;
343 
344         pin_mut!(done);
345 
346         let current_thread = thread::current().id();
347         let state = sys::platform_state()?;
348         let waker = state.waker_ref();
349         let mut cx = task::Context::from_waker(&waker);
350         let mut done_polled = false;
351 
352         LOCAL_CONTEXT.with(|local_ctx| {
353             let next_local = || local_ctx.lock().queue.pop_front();
354             let next_global = || self.shared.lock().queue.pop_front();
355 
356             let mut tick = Wrapping(0u32);
357 
358             loop {
359                 tick += Wrapping(1);
360 
361                 // If there are always tasks available to run in either the local or the global
362                 // queue then we may go a long time without fetching completed events from the
363                 // underlying platform driver. Poll the driver once in a while to prevent this from
364                 // happening.
365                 if tick.0 % 31 == 0 {
366                     // A zero timeout will fetch new events without blocking.
367                     self.get_events(&state, Some(Duration::from_millis(0)))?;
368                 }
369 
370                 let was_woken = state.start_processing();
371                 if was_woken || !done_polled {
372                     done_polled = true;
373                     if let Poll::Ready(v) = done.as_mut().poll(&mut cx) {
374                         return Ok(v);
375                     }
376                 }
377 
378                 // If there are always tasks in the local queue then any tasks in the global queue
379                 // will get starved. Pull tasks out of the global queue every once in a while even
380                 // when there are still local tasks available to prevent this.
381                 let next_runnable = if tick.0 % 13 == 0 {
382                     next_global().or_else(next_local)
383                 } else {
384                     next_local().or_else(next_global)
385                 };
386 
387                 if let Some(runnable) = next_runnable {
388                     runnable.run();
389                     continue;
390                 }
391 
392                 // We're about to block so first check that new tasks have not snuck in and set the
393                 // waker so that we can be woken up when tasks are re-scheduled.
394                 let deadline = {
395                     let mut ctx = local_ctx.lock();
396                     if !ctx.queue.is_empty() {
397                         // Some more tasks managed to sneak in.  Go back to the start of the loop.
398                         continue;
399                     }
400 
401                     // There are no more tasks to run so set the waker.
402                     if ctx.waker.is_none() {
403                         ctx.waker = Some(cx.waker().clone());
404                     }
405 
406                     // TODO: Replace with `last_entry` once it is stabilized.
407                     ctx.timers.keys().next_back().cloned()
408                 };
409                 {
410                     let mut shared = self.shared.lock();
411                     if !shared.queue.is_empty() {
412                         // More tasks were added to the global queue. Go back to the start of the loop.
413                         continue;
414                     }
415 
416                     // We're going to block so add ourselves to the idle worker list.
417                     shared
418                         .idle_workers
419                         .push_back((current_thread, cx.waker().clone()));
420                 };
421 
422                 // Now wait to be woken up.
423                 let timeout = deadline.map(|d| d.0.saturating_duration_since(Instant::now()));
424                 self.get_events(&state, timeout)?;
425 
426                 // Remove from idle workers.
427                 {
428                     let mut shared = self.shared.lock();
429                     if let Some(idx) = shared
430                         .idle_workers
431                         .iter()
432                         .position(|(id, _)| id == &current_thread)
433                     {
434                         shared.idle_workers.swap_remove_back(idx);
435                     }
436                 }
437 
438                 // Reset the ticks since we just fetched new events from the platform driver.
439                 tick = Wrapping(0);
440             }
441         })
442     }
443 
get_events<S: PlatformState>( &self, state: &S, timeout: Option<Duration>, ) -> anyhow::Result<()>444     fn get_events<S: PlatformState>(
445         &self,
446         state: &S,
447         timeout: Option<Duration>,
448     ) -> anyhow::Result<()> {
449         state.wait(timeout)?;
450 
451         // Timer maintenance.
452         let expired = LOCAL_CONTEXT.with(|local_ctx| {
453             let mut ctx = local_ctx.lock();
454             let now = Instant::now();
455             ctx.timers.split_off(&Reverse(now))
456         });
457 
458         // We cannot wake the timers while holding the lock because the schedule function for the
459         // task that's waiting on the timer may try to acquire the lock.
460         for (deadline, wakers) in expired {
461             debug_assert!(deadline.0 <= Instant::now());
462             for w in wakers {
463                 w.wake();
464             }
465         }
466 
467         Ok(())
468     }
469 }
470 
471 // A trait that represents any thread-local platform-specific state that needs to be held on behalf
472 // of the `Executor`.
473 pub(crate) trait PlatformState {
474     // Indicates that the `Executor` is about to start processing futures that have been woken up.
475     //
476     // Implementations may use this as an indicator to skip unnecessary work when new tasks are
477     // woken up as the `Executor` will eventually get around to processing them on its own.
478     //
479     // `start_processing` must return true if one or more futures were woken up since the last call
480     // to `start_processing`. Otherwise it may return false.
start_processing(&self) -> bool481     fn start_processing(&self) -> bool;
482 
483     // Returns a `WakerRef` that can be used to wake up the current thread.
waker_ref(&self) -> WakerRef484     fn waker_ref(&self) -> WakerRef;
485 
486     // Waits for one or more futures to be woken up.
487     //
488     // This method should check with the underlying OS if any asynchronous IO operations have
489     // completed and then wake up the associated futures.
490     //
491     // If `timeout` is provided then this method should block until either one or more futures are
492     // woken up or the timeout duration elapses. If `timeout` has a zero duration then this method
493     // should fetch completed asynchronous IO operations and then immediately return.
494     //
495     // If `timeout` is not provided then this method should block until one or more futures are
496     // woken up.
wait(&self, timeout: Option<Duration>) -> anyhow::Result<()>497     fn wait(&self, timeout: Option<Duration>) -> anyhow::Result<()>;
498 }
499 
#[cfg(test)]
mod test {
    use super::*;

    use std::{
        convert::TryFrom,
        fs::OpenOptions,
        mem,
        pin::Pin,
        thread::{self, JoinHandle},
        time::Instant,
    };

    use futures::{
        channel::{mpsc, oneshot},
        future::{join3, select, Either},
        sink::SinkExt,
        stream::{self, FuturesUnordered, StreamExt},
    };

    use crate::{File, OwnedIoBuf};

    // A read from a pipe that is never written to stays pending. `select` must therefore resolve
    // to `done`, and dropping the still-pending read must be handled cleanly.
    #[test]
    fn basic() {
        async fn do_it() {
            let (r, _w) = sys_util::pipe(true).unwrap();
            let done = async { 5usize };

            let rx = File::try_from(r).unwrap();
            let mut buf = 0u64.to_ne_bytes();
            let pending = rx.read(&mut buf, None);
            pin_mut!(pending, done);

            match select(pending, done).await {
                Either::Right((5, pending)) => drop(pending),
                _ => panic!("unexpected select result"),
            }
        }

        Executor::new().run_until(do_it()).unwrap();
    }

    #[derive(Default)]
    struct QuitShared {
        wakers: Vec<task::Waker>,
        should_quit: bool,
    }

    /// A future that stays pending until `quit` is called on any clone of it. Used to make
    /// `run_until` return on worker threads at the end of a test.
    #[derive(Clone, Default)]
    struct Quit {
        shared: Arc<Mutex<QuitShared>>,
    }

    impl Quit {
        fn quit(self) {
            // Collect the wakers under the lock, but wake them after releasing it.
            let wakers = {
                let mut shared = self.shared.lock();
                shared.should_quit = true;
                mem::take(&mut shared.wakers)
            };

            for w in wakers {
                w.wake();
            }
        }
    }

    impl Future for Quit {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
            let mut shared = self.shared.lock();
            if shared.should_quit {
                return Poll::Ready(());
            }

            if shared.wakers.iter().all(|w| !w.will_wake(cx.waker())) {
                shared.wakers.push(cx.waker().clone());
            }

            Poll::Pending
        }
    }

    // `transfer_data` does its I/O through `spawn_local`. The future it returns must still be
    // `Send`, because it is handed to `spawn` and may be polled by any of the worker threads.
    #[test]
    fn outer_future_is_send() {
        const NUM_THREADS: usize = 3;
        const CHUNK_SIZE: usize = 32;

        async fn read_iobuf(
            ex: &Executor,
            f: File,
            buf: OwnedIoBuf,
        ) -> (anyhow::Result<usize>, OwnedIoBuf, File) {
            let (tx, rx) = oneshot::channel();
            ex.spawn_local(async move {
                let (res, buf) = f.read_iobuf(buf, None).await;
                let _ = tx.send((res, buf, f));
            })
            .detach();
            rx.await.unwrap()
        }

        async fn write_iobuf(
            ex: &Executor,
            f: File,
            buf: OwnedIoBuf,
        ) -> (anyhow::Result<usize>, OwnedIoBuf, File) {
            let (tx, rx) = oneshot::channel();
            ex.spawn_local(async move {
                let (res, buf) = f.write_iobuf(buf, None).await;
                let _ = tx.send((res, buf, f));
            })
            .detach();
            rx.await.unwrap()
        }

        async fn transfer_data(
            ex: Executor,
            mut from: File,
            mut to: File,
            len: usize,
        ) -> Result<usize> {
            let mut rem = len;
            let mut buf = OwnedIoBuf::new(vec![0xa2u8; CHUNK_SIZE]);
            while rem > 0 {
                let (res, data, f) = read_iobuf(&ex, from, buf).await;
                let count = res?;
                buf = data;
                from = f;
                if count == 0 {
                    // End of file. Return the number of bytes transferred.
                    return Ok(len - rem);
                }
                assert_eq!(count, CHUNK_SIZE);

                let (res, data, t) = write_iobuf(&ex, to, buf).await;
                let count = res?;
                buf = data;
                to = t;
                assert_eq!(count, CHUNK_SIZE);

                rem = rem.saturating_sub(count);
            }

            Ok(len)
        }

        fn do_it() -> anyhow::Result<()> {
            let ex = Executor::new();
            let (rx, tx) = sys_util::pipe(true)?;
            let zero = File::open("/dev/zero")?;
            let zero_bytes = CHUNK_SIZE * 7;
            let zero_to_pipe = ex.spawn(transfer_data(
                ex.clone(),
                zero,
                File::try_from(tx.try_clone()?)?,
                zero_bytes,
            ));

            let rand = File::open("/dev/urandom")?;
            let rand_bytes = CHUNK_SIZE * 19;
            let rand_to_pipe = ex.spawn(transfer_data(
                ex.clone(),
                rand,
                File::try_from(tx)?,
                rand_bytes,
            ));

            let null = OpenOptions::new().write(true).open("/dev/null")?;
            let null_bytes = zero_bytes + rand_bytes;
            let pipe_to_null = ex.spawn(transfer_data(
                ex.clone(),
                File::try_from(rx)?,
                File::try_from(null)?,
                null_bytes,
            ));

            let mut threads = Vec::with_capacity(NUM_THREADS);
            let quit = Quit::default();
            for _ in 0..NUM_THREADS {
                let thread_ex = ex.clone();
                let thread_quit = quit.clone();
                threads.push(thread::spawn(move || thread_ex.run_until(thread_quit)))
            }
            ex.run_until(join3(
                async { assert_eq!(pipe_to_null.await.unwrap(), null_bytes) },
                async { assert_eq!(zero_to_pipe.await.unwrap(), zero_bytes) },
                async { assert_eq!(rand_to_pipe.await.unwrap(), rand_bytes) },
            ))?;

            quit.quit();
            for t in threads {
                t.join().unwrap().unwrap();
            }

            Ok(())
        }

        do_it().unwrap();
    }

    // Pushes values through a chain of `NUM_CHANNELS` forwarding tasks while `NUM_THREADS` extra
    // threads run the same executor, and checks that every value arrives in order.
    #[test]
    fn thread_pool() {
        const NUM_THREADS: usize = 8;
        const NUM_CHANNELS: usize = 19;
        const NUM_ITERATIONS: usize = 71;

        let ex = Executor::new();

        let tasks = FuturesUnordered::new();
        let (mut tx, mut rx) = mpsc::channel(10);
        tasks.push(ex.spawn(async move {
            for i in 0..NUM_ITERATIONS {
                tx.send(i).await?;
            }

            Ok::<(), anyhow::Error>(())
        }));

        for _ in 0..NUM_CHANNELS {
            let (mut task_tx, task_rx) = mpsc::channel(10);
            tasks.push(ex.spawn(async move {
                while let Some(v) = rx.next().await {
                    task_tx.send(v).await?;
                }

                Ok::<(), anyhow::Error>(())
            }));

            rx = task_rx;
        }

        tasks.push(ex.spawn(async move {
            let mut zip = rx.zip(stream::iter(0..NUM_ITERATIONS));
            while let Some((l, r)) = zip.next().await {
                assert_eq!(l, r);
            }

            Ok::<(), anyhow::Error>(())
        }));

        let quit = Quit::default();
        let mut threads = Vec::with_capacity(NUM_THREADS);
        for _ in 0..NUM_THREADS {
            let thread_ex = ex.clone();
            let thread_quit = quit.clone();
            threads.push(thread::spawn(move || thread_ex.run_until(thread_quit)));
        }

        let results = ex
            .run_until(tasks.collect::<Vec<anyhow::Result<()>>>())
            .unwrap();
        results
            .into_iter()
            .collect::<anyhow::Result<Vec<()>>>()
            .unwrap();

        quit.quit();
        for t in threads {
            t.join().unwrap().unwrap();
        }
    }

    // Sends a message on `tx` once there is an idle worker in `Executor` or 5 seconds have passed.
    // Sends true if this function observed an idle worker and false otherwise.
    //
    // The result is recorded at the moment the idle worker is seen, not by re-reading the clock
    // afterwards. Re-reading the clock can report `true` without any observation (loop exits with
    // `now == deadline`) or `false` after a real observation (deadline crossed right after the
    // `break`).
    fn notify_on_idle_worker(ex: Executor, tx: oneshot::Sender<bool>) {
        let deadline = Instant::now() + Duration::from_secs(5);
        let mut observed_idle_worker = false;
        while Instant::now() < deadline {
            // Wait for the main thread to add itself to the idle worker list.
            if !ex.shared.lock().idle_workers.is_empty() {
                observed_idle_worker = true;
                break;
            }

            thread::sleep(Duration::from_millis(10));
        }

        tx.send(observed_idle_worker).unwrap();
    }

    #[test]
    fn wakeup_run_until() {
        let (tx, rx) = oneshot::channel();

        let ex = Executor::new();

        let thread_ex = ex.clone();
        let waker_thread = thread::spawn(move || notify_on_idle_worker(thread_ex, tx));

        // Since we're using `run_until` the wakeup path won't use the regular scheduling functions.
        let success = ex.run_until(rx).unwrap().unwrap();
        assert!(success);
        assert!(ex.shared.lock().idle_workers.is_empty());

        waker_thread.join().unwrap();
    }

    #[test]
    fn wakeup_local_task() {
        let (tx, rx) = oneshot::channel();

        let ex = Executor::new();

        let thread_ex = ex.clone();
        let waker_thread = thread::spawn(move || notify_on_idle_worker(thread_ex, tx));

        // By using `spawn_local`, the wakeup path will go via LOCAL_CTX.
        let task = ex.spawn_local(rx);
        let success = ex.run_until(task).unwrap().unwrap();
        assert!(success);
        assert!(ex.shared.lock().idle_workers.is_empty());

        waker_thread.join().unwrap();
    }

    #[test]
    fn wakeup_global_task() {
        let (tx, rx) = oneshot::channel();

        let ex = Executor::new();

        let thread_ex = ex.clone();
        let waker_thread = thread::spawn(move || notify_on_idle_worker(thread_ex, tx));

        // By using `spawn`, the wakeup path will go via `ex.shared`.
        let task = ex.spawn(rx);
        let success = ex.run_until(task).unwrap().unwrap();
        assert!(success);
        assert!(ex.shared.lock().idle_workers.is_empty());

        waker_thread.join().unwrap();
    }

    // Each worker's `run_until` future only makes progress when its own channel receives a
    // message. The test checks that the reply always comes from the thread that was messaged,
    // i.e. the right worker was woken.
    #[test]
    fn wake_up_correct_worker() {
        struct ThreadData {
            id: ThreadId,
            sender: mpsc::Sender<()>,
            handle: JoinHandle<anyhow::Result<()>>,
        }

        const NUM_THREADS: usize = 7;
        const NUM_ITERATIONS: usize = 119;

        let ex = Executor::new();

        let (tx, mut rx) = mpsc::channel(0);
        let mut threads = Vec::with_capacity(NUM_THREADS);
        for _ in 0..NUM_THREADS {
            let (sender, mut receiver) = mpsc::channel(0);
            let mut thread_tx = tx.clone();
            let thread_ex = ex.clone();
            let handle = thread::spawn(move || {
                let id = thread::current().id();
                thread_ex
                    .run_until(async move {
                        while let Some(()) = receiver.next().await {
                            thread_tx.send(id).await?;
                        }

                        Ok(())
                    })
                    .unwrap()
            });

            let id = handle.thread().id();
            threads.push(ThreadData { id, sender, handle });
        }

        ex.run_until(async {
            for i in 0..NUM_ITERATIONS {
                let data = &mut threads[i % NUM_THREADS];
                data.sender.send(()).await?;
                assert_eq!(rx.next().await.unwrap(), data.id);
            }

            Ok::<(), anyhow::Error>(())
        })
        .unwrap()
        .unwrap();

        for t in threads {
            let ThreadData { id, sender, handle } = t;

            // Dropping the sender will close the channel and cause the thread to exit.
            drop((id, sender));
            handle.join().unwrap().unwrap();
        }
    }
}
894