• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Runs `!Send` futures on the current thread.
2 use crate::loom::cell::UnsafeCell;
3 use crate::loom::sync::{Arc, Mutex};
4 #[cfg(tokio_unstable)]
5 use crate::runtime;
6 use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
7 use crate::runtime::{context, ThreadId};
8 use crate::sync::AtomicWaker;
9 use crate::util::RcCell;
10 
11 use std::cell::Cell;
12 use std::collections::VecDeque;
13 use std::fmt;
14 use std::future::Future;
15 use std::marker::PhantomData;
16 use std::pin::Pin;
17 use std::rc::Rc;
18 use std::task::Poll;
19 
20 use pin_project_lite::pin_project;
21 
22 cfg_rt! {
23     /// A set of tasks which are executed on the same thread.
24     ///
25     /// In some cases, it is necessary to run one or more futures that do not
26     /// implement [`Send`] and thus are unsafe to send between threads. In these
27     /// cases, a [local task set] may be used to schedule one or more `!Send`
28     /// futures to run together on the same thread.
29     ///
30     /// For example, the following code will not compile:
31     ///
32     /// ```rust,compile_fail
33     /// use std::rc::Rc;
34     ///
35     /// #[tokio::main]
36     /// async fn main() {
37     ///     // `Rc` does not implement `Send`, and thus may not be sent between
38     ///     // threads safely.
39     ///     let nonsend_data = Rc::new("my nonsend data...");
40     ///
41     ///     let nonsend_data = nonsend_data.clone();
42     ///     // Because the `async` block here moves `nonsend_data`, the future is `!Send`.
43     ///     // Since `tokio::spawn` requires the spawned future to implement `Send`, this
44     ///     // will not compile.
45     ///     tokio::spawn(async move {
46     ///         println!("{}", nonsend_data);
47     ///         // ...
48     ///     }).await.unwrap();
49     /// }
50     /// ```
51     ///
52     /// # Use with `run_until`
53     ///
54     /// To spawn `!Send` futures, we can use a local task set to schedule them
55     /// on the thread calling [`Runtime::block_on`]. When running inside of the
56     /// local task set, we can use [`task::spawn_local`], which can spawn
57     /// `!Send` futures. For example:
58     ///
59     /// ```rust
60     /// use std::rc::Rc;
61     /// use tokio::task;
62     ///
63     /// #[tokio::main]
64     /// async fn main() {
65     ///     let nonsend_data = Rc::new("my nonsend data...");
66     ///
67     ///     // Construct a local task set that can run `!Send` futures.
68     ///     let local = task::LocalSet::new();
69     ///
70     ///     // Run the local task set.
71     ///     local.run_until(async move {
72     ///         let nonsend_data = nonsend_data.clone();
73     ///         // `spawn_local` ensures that the future is spawned on the local
74     ///         // task set.
75     ///         task::spawn_local(async move {
76     ///             println!("{}", nonsend_data);
77     ///             // ...
78     ///         }).await.unwrap();
79     ///     }).await;
80     /// }
81     /// ```
82     /// **Note:** The `run_until` method can only be used in `#[tokio::main]`,
83     /// `#[tokio::test]` or directly inside a call to [`Runtime::block_on`]. It
84     /// cannot be used inside a task spawned with `tokio::spawn`.
85     ///
86     /// ## Awaiting a `LocalSet`
87     ///
88     /// Additionally, a `LocalSet` itself implements `Future`, completing when
89     /// *all* tasks spawned on the `LocalSet` complete. This can be used to run
90     /// several futures on a `LocalSet` and drive the whole set until they
91     /// complete. For example,
92     ///
93     /// ```rust
94     /// use tokio::{task, time};
95     /// use std::rc::Rc;
96     ///
97     /// #[tokio::main]
98     /// async fn main() {
99     ///     let nonsend_data = Rc::new("world");
100     ///     let local = task::LocalSet::new();
101     ///
102     ///     let nonsend_data2 = nonsend_data.clone();
103     ///     local.spawn_local(async move {
104     ///         // ...
105     ///         println!("hello {}", nonsend_data2)
106     ///     });
107     ///
108     ///     local.spawn_local(async move {
109     ///         time::sleep(time::Duration::from_millis(100)).await;
110     ///         println!("goodbye {}", nonsend_data)
111     ///     });
112     ///
113     ///     // ...
114     ///
115     ///     local.await;
116     /// }
117     /// ```
118     /// **Note:** Awaiting a `LocalSet` can only be done inside
119     /// `#[tokio::main]`, `#[tokio::test]` or directly inside a call to
120     /// [`Runtime::block_on`]. It cannot be used inside a task spawned with
121     /// `tokio::spawn`.
122     ///
123     /// ## Use inside `tokio::spawn`
124     ///
125     /// The two methods mentioned above cannot be used inside `tokio::spawn`, so
126     /// to spawn `!Send` futures from inside `tokio::spawn`, we need to do
127     /// something else. The solution is to create the `LocalSet` somewhere else,
128     /// and communicate with it using an [`mpsc`] channel.
129     ///
130     /// The following example puts the `LocalSet` inside a new thread.
131     /// ```
132     /// use tokio::runtime::Builder;
133     /// use tokio::sync::{mpsc, oneshot};
134     /// use tokio::task::LocalSet;
135     ///
136     /// // This struct describes the task you want to spawn. Here we include
137     /// // some simple examples. The oneshot channel allows sending a response
138     /// // to the spawner.
139     /// #[derive(Debug)]
140     /// enum Task {
141     ///     PrintNumber(u32),
142     ///     AddOne(u32, oneshot::Sender<u32>),
143     /// }
144     ///
145     /// #[derive(Clone)]
146     /// struct LocalSpawner {
147     ///    send: mpsc::UnboundedSender<Task>,
148     /// }
149     ///
150     /// impl LocalSpawner {
151     ///     pub fn new() -> Self {
152     ///         let (send, mut recv) = mpsc::unbounded_channel();
153     ///
154     ///         let rt = Builder::new_current_thread()
155     ///             .enable_all()
156     ///             .build()
157     ///             .unwrap();
158     ///
159     ///         std::thread::spawn(move || {
160     ///             let local = LocalSet::new();
161     ///
162     ///             local.spawn_local(async move {
163     ///                 while let Some(new_task) = recv.recv().await {
164     ///                     tokio::task::spawn_local(run_task(new_task));
165     ///                 }
166     ///                 // If the while loop returns, then all the LocalSpawner
167     ///                 // objects have been dropped.
168     ///             });
169     ///
170     ///             // This will return once all senders are dropped and all
171     ///             // spawned tasks have returned.
172     ///             rt.block_on(local);
173     ///         });
174     ///
175     ///         Self {
176     ///             send,
177     ///         }
178     ///     }
179     ///
180     ///     pub fn spawn(&self, task: Task) {
181     ///         self.send.send(task).expect("Thread with LocalSet has shut down.");
182     ///     }
183     /// }
184     ///
185     /// // This task may do !Send stuff. We use printing a number as an example,
186     /// // but it could be anything.
187     /// //
188     /// // The Task struct is an enum to support spawning many different kinds
189     /// // of operations.
190     /// async fn run_task(task: Task) {
191     ///     match task {
192     ///         Task::PrintNumber(n) => {
193     ///             println!("{}", n);
194     ///         },
195     ///         Task::AddOne(n, response) => {
196     ///             // We ignore failures to send the response.
197     ///             let _ = response.send(n + 1);
198     ///         },
199     ///     }
200     /// }
201     ///
202     /// #[tokio::main]
203     /// async fn main() {
204     ///     let spawner = LocalSpawner::new();
205     ///
206     ///     let (send, response) = oneshot::channel();
207     ///     spawner.spawn(Task::AddOne(10, send));
208     ///     let eleven = response.await.unwrap();
209     ///     assert_eq!(eleven, 11);
210     /// }
211     /// ```
212     ///
213     /// [`Send`]: trait@std::marker::Send
214     /// [local task set]: struct@LocalSet
215     /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on
216     /// [`task::spawn_local`]: fn@spawn_local
217     /// [`mpsc`]: mod@crate::sync::mpsc
218     pub struct LocalSet {
219         /// Current scheduler tick.
220         tick: Cell<u8>,
221 
222         /// State available from thread-local.
223         context: Rc<Context>,
224 
225         /// This type should not be Send.
226         _not_send: PhantomData<*const ()>,
227     }
228 }
229 
230 /// State available from the thread-local.
231 struct Context {
232     /// State shared between threads.
233     shared: Arc<Shared>,
234 
235     /// True if a task panicked without being handled and the local set is
236     /// configured to shutdown on unhandled panic.
237     unhandled_panic: Cell<bool>,
238 }
239 
240 /// LocalSet state shared between threads.
241 struct Shared {
242     /// # Safety
243     ///
244     /// This field must *only* be accessed from the thread that owns the
245     /// `LocalSet` (i.e., `Thread::current().id() == owner`).
246     local_state: LocalState,
247 
248     /// Remote run queue sender.
249     queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>,
250 
251     /// Wake the `LocalSet` task.
252     waker: AtomicWaker,
253 
254     /// How to respond to unhandled task panics.
255     #[cfg(tokio_unstable)]
256     pub(crate) unhandled_panic: crate::runtime::UnhandledPanic,
257 }
258 
259 /// Tracks the `LocalSet` state that must only be accessed from the thread that
260 /// created the `LocalSet`.
261 struct LocalState {
262     /// The `ThreadId` of the thread that owns the `LocalSet`.
263     owner: ThreadId,
264 
265     /// Local run queue sender and receiver.
266     local_queue: UnsafeCell<VecDeque<task::Notified<Arc<Shared>>>>,
267 
268     /// Collection of all active tasks spawned onto this executor.
269     owned: LocalOwnedTasks<Arc<Shared>>,
270 }
271 
272 pin_project! {
273     #[derive(Debug)]
274     struct RunUntil<'a, F> {
275         local_set: &'a LocalSet,
276         #[pin]
277         future: F,
278     }
279 }
280 
281 tokio_thread_local!(static CURRENT: LocalData = const { LocalData {
282     ctx: RcCell::new(),
283 } });
284 
285 struct LocalData {
286     ctx: RcCell<Context>,
287 }
288 
289 cfg_rt! {
290     /// Spawns a `!Send` future on the current [`LocalSet`].
291     ///
292     /// The spawned future will run on the same thread that called `spawn_local`.
293     ///
294     /// The provided future will start running in the background immediately
295     /// when `spawn_local` is called, even if you don't await the returned
296     /// `JoinHandle`.
297     ///
298     /// # Panics
299     ///
300     /// This function panics if called outside of a [`LocalSet`].
301     ///
302     /// Note that if [`tokio::spawn`] is used from within a `LocalSet`, the
303     /// resulting new task will _not_ be inside the `LocalSet`, so you must use
304     /// `spawn_local` if you want to stay within the `LocalSet`.
305     ///
306     /// # Examples
307     ///
308     /// ```rust
309     /// use std::rc::Rc;
310     /// use tokio::task;
311     ///
312     /// #[tokio::main]
313     /// async fn main() {
314     ///     let nonsend_data = Rc::new("my nonsend data...");
315     ///
316     ///     let local = task::LocalSet::new();
317     ///
318     ///     // Run the local task set.
319     ///     local.run_until(async move {
320     ///         let nonsend_data = nonsend_data.clone();
321     ///         task::spawn_local(async move {
322     ///             println!("{}", nonsend_data);
323     ///             // ...
324     ///         }).await.unwrap();
325     ///     }).await;
326     /// }
327     /// ```
328     ///
329     /// [`LocalSet`]: struct@crate::task::LocalSet
330     /// [`tokio::spawn`]: fn@crate::task::spawn
331     #[track_caller]
332     pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output>
333     where
334         F: Future + 'static,
335         F::Output: 'static,
336     {
337         spawn_local_inner(future, None)
338     }
339 
340 
341     #[track_caller]
342     pub(super) fn spawn_local_inner<F>(future: F, name: Option<&str>) -> JoinHandle<F::Output>
343     where F: Future + 'static,
344           F::Output: 'static
345     {
346         match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) {
347             None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
348             Some(cx) => cx.spawn(future, name)
349        }
350     }
351 }
352 
/// Initial capacity of both the local and the remote run queues.
const INITIAL_CAPACITY: usize = 64;

/// Maximum number of tasks polled during a single scheduler tick.
const MAX_TASKS_PER_TICK: usize = 61;

/// How often the scheduler checks the remote queue first.
///
/// On every tick whose counter is a multiple of this value, the remote
/// queue is polled before the local queue. On all other ticks, the local
/// queue is preferred.
const REMOTE_FIRST_INTERVAL: u8 = 31;
361 
362 /// Context guard for LocalSet
363 pub struct LocalEnterGuard(Option<Rc<Context>>);
364 
365 impl Drop for LocalEnterGuard {
drop(&mut self)366     fn drop(&mut self) {
367         CURRENT.with(|LocalData { ctx, .. }| {
368             ctx.set(self.0.take());
369         })
370     }
371 }
372 
373 impl fmt::Debug for LocalEnterGuard {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result374     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
375         f.debug_struct("LocalEnterGuard").finish()
376     }
377 }
378 
379 impl LocalSet {
380     /// Returns a new local task set.
new() -> LocalSet381     pub fn new() -> LocalSet {
382         let owner = context::thread_id().expect("cannot create LocalSet during thread shutdown");
383 
384         LocalSet {
385             tick: Cell::new(0),
386             context: Rc::new(Context {
387                 shared: Arc::new(Shared {
388                     local_state: LocalState {
389                         owner,
390                         owned: LocalOwnedTasks::new(),
391                         local_queue: UnsafeCell::new(VecDeque::with_capacity(INITIAL_CAPACITY)),
392                     },
393                     queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))),
394                     waker: AtomicWaker::new(),
395                     #[cfg(tokio_unstable)]
396                     unhandled_panic: crate::runtime::UnhandledPanic::Ignore,
397                 }),
398                 unhandled_panic: Cell::new(false),
399             }),
400             _not_send: PhantomData,
401         }
402     }
403 
404     /// Enters the context of this `LocalSet`.
405     ///
406     /// The [`spawn_local`] method will spawn tasks on the `LocalSet` whose
407     /// context you are inside.
408     ///
409     /// [`spawn_local`]: fn@crate::task::spawn_local
enter(&self) -> LocalEnterGuard410     pub fn enter(&self) -> LocalEnterGuard {
411         CURRENT.with(|LocalData { ctx, .. }| {
412             let old = ctx.replace(Some(self.context.clone()));
413             LocalEnterGuard(old)
414         })
415     }
416 
417     /// Spawns a `!Send` task onto the local task set.
418     ///
419     /// This task is guaranteed to be run on the current thread.
420     ///
421     /// Unlike the free function [`spawn_local`], this method may be used to
422     /// spawn local tasks when the `LocalSet` is _not_ running. The provided
423     /// future will start running once the `LocalSet` is next started, even if
424     /// you don't await the returned `JoinHandle`.
425     ///
426     /// # Examples
427     ///
428     /// ```rust
429     /// use tokio::task;
430     ///
431     /// #[tokio::main]
432     /// async fn main() {
433     ///     let local = task::LocalSet::new();
434     ///
435     ///     // Spawn a future on the local set. This future will be run when
436     ///     // we call `run_until` to drive the task set.
437     ///     local.spawn_local(async {
438     ///        // ...
439     ///     });
440     ///
441     ///     // Run the local task set.
442     ///     local.run_until(async move {
443     ///         // ...
444     ///     }).await;
445     ///
446     ///     // When `run` finishes, we can spawn _more_ futures, which will
447     ///     // run in subsequent calls to `run_until`.
448     ///     local.spawn_local(async {
449     ///        // ...
450     ///     });
451     ///
452     ///     local.run_until(async move {
453     ///         // ...
454     ///     }).await;
455     /// }
456     /// ```
457     /// [`spawn_local`]: fn@spawn_local
458     #[track_caller]
spawn_local<F>(&self, future: F) -> JoinHandle<F::Output> where F: Future + 'static, F::Output: 'static,459     pub fn spawn_local<F>(&self, future: F) -> JoinHandle<F::Output>
460     where
461         F: Future + 'static,
462         F::Output: 'static,
463     {
464         self.spawn_named(future, None)
465     }
466 
467     /// Runs a future to completion on the provided runtime, driving any local
468     /// futures spawned on this task set on the current thread.
469     ///
470     /// This runs the given future on the runtime, blocking until it is
471     /// complete, and yielding its resolved result. Any tasks or timers which
472     /// the future spawns internally will be executed on the runtime. The future
473     /// may also call [`spawn_local`] to spawn_local additional local futures on the
474     /// current thread.
475     ///
476     /// This method should not be called from an asynchronous context.
477     ///
478     /// # Panics
479     ///
480     /// This function panics if the executor is at capacity, if the provided
481     /// future panics, or if called within an asynchronous execution context.
482     ///
483     /// # Notes
484     ///
485     /// Since this function internally calls [`Runtime::block_on`], and drives
486     /// futures in the local task set inside that call to `block_on`, the local
487     /// futures may not use [in-place blocking]. If a blocking call needs to be
488     /// issued from a local task, the [`spawn_blocking`] API may be used instead.
489     ///
490     /// For example, this will panic:
491     /// ```should_panic
492     /// use tokio::runtime::Runtime;
493     /// use tokio::task;
494     ///
495     /// let rt  = Runtime::new().unwrap();
496     /// let local = task::LocalSet::new();
497     /// local.block_on(&rt, async {
498     ///     let join = task::spawn_local(async {
499     ///         let blocking_result = task::block_in_place(|| {
500     ///             // ...
501     ///         });
502     ///         // ...
503     ///     });
504     ///     join.await.unwrap();
505     /// })
506     /// ```
507     /// This, however, will not panic:
508     /// ```
509     /// use tokio::runtime::Runtime;
510     /// use tokio::task;
511     ///
512     /// let rt  = Runtime::new().unwrap();
513     /// let local = task::LocalSet::new();
514     /// local.block_on(&rt, async {
515     ///     let join = task::spawn_local(async {
516     ///         let blocking_result = task::spawn_blocking(|| {
517     ///             // ...
518     ///         }).await;
519     ///         // ...
520     ///     });
521     ///     join.await.unwrap();
522     /// })
523     /// ```
524     ///
525     /// [`spawn_local`]: fn@spawn_local
526     /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on
527     /// [in-place blocking]: fn@crate::task::block_in_place
528     /// [`spawn_blocking`]: fn@crate::task::spawn_blocking
529     #[track_caller]
530     #[cfg(feature = "rt")]
531     #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
block_on<F>(&self, rt: &crate::runtime::Runtime, future: F) -> F::Output where F: Future,532     pub fn block_on<F>(&self, rt: &crate::runtime::Runtime, future: F) -> F::Output
533     where
534         F: Future,
535     {
536         rt.block_on(self.run_until(future))
537     }
538 
539     /// Runs a future to completion on the local set, returning its output.
540     ///
541     /// This returns a future that runs the given future with a local set,
542     /// allowing it to call [`spawn_local`] to spawn additional `!Send` futures.
543     /// Any local futures spawned on the local set will be driven in the
544     /// background until the future passed to `run_until` completes. When the future
545     /// passed to `run` finishes, any local futures which have not completed
546     /// will remain on the local set, and will be driven on subsequent calls to
547     /// `run_until` or when [awaiting the local set] itself.
548     ///
549     /// # Examples
550     ///
551     /// ```rust
552     /// use tokio::task;
553     ///
554     /// #[tokio::main]
555     /// async fn main() {
556     ///     task::LocalSet::new().run_until(async {
557     ///         task::spawn_local(async move {
558     ///             // ...
559     ///         }).await.unwrap();
560     ///         // ...
561     ///     }).await;
562     /// }
563     /// ```
564     ///
565     /// [`spawn_local`]: fn@spawn_local
566     /// [awaiting the local set]: #awaiting-a-localset
run_until<F>(&self, future: F) -> F::Output where F: Future,567     pub async fn run_until<F>(&self, future: F) -> F::Output
568     where
569         F: Future,
570     {
571         let run_until = RunUntil {
572             future,
573             local_set: self,
574         };
575         run_until.await
576     }
577 
spawn_named<F>( &self, future: F, name: Option<&str>, ) -> JoinHandle<F::Output> where F: Future + 'static, F::Output: 'static,578     pub(in crate::task) fn spawn_named<F>(
579         &self,
580         future: F,
581         name: Option<&str>,
582     ) -> JoinHandle<F::Output>
583     where
584         F: Future + 'static,
585         F::Output: 'static,
586     {
587         let handle = self.context.spawn(future, name);
588 
589         // Because a task was spawned from *outside* the `LocalSet`, wake the
590         // `LocalSet` future to execute the new task, if it hasn't been woken.
591         //
592         // Spawning via the free fn `spawn` does not require this, as it can
593         // only be called from *within* a future executing on the `LocalSet` —
594         // in that case, the `LocalSet` must already be awake.
595         self.context.shared.waker.wake();
596         handle
597     }
598 
599     /// Ticks the scheduler, returning whether the local future needs to be
600     /// notified again.
tick(&self) -> bool601     fn tick(&self) -> bool {
602         for _ in 0..MAX_TASKS_PER_TICK {
603             // Make sure we didn't hit an unhandled panic
604             if self.context.unhandled_panic.get() {
605                 panic!("a spawned task panicked and the LocalSet is configured to shutdown on unhandled panic");
606             }
607 
608             match self.next_task() {
609                 // Run the task
610                 //
611                 // Safety: As spawned tasks are `!Send`, `run_unchecked` must be
612                 // used. We are responsible for maintaining the invariant that
613                 // `run_unchecked` is only called on threads that spawned the
614                 // task initially. Because `LocalSet` itself is `!Send`, and
615                 // `spawn_local` spawns into the `LocalSet` on the current
616                 // thread, the invariant is maintained.
617                 Some(task) => crate::runtime::coop::budget(|| task.run()),
618                 // We have fully drained the queue of notified tasks, so the
619                 // local future doesn't need to be notified again — it can wait
620                 // until something else wakes a task in the local set.
621                 None => return false,
622             }
623         }
624 
625         true
626     }
627 
next_task(&self) -> Option<task::LocalNotified<Arc<Shared>>>628     fn next_task(&self) -> Option<task::LocalNotified<Arc<Shared>>> {
629         let tick = self.tick.get();
630         self.tick.set(tick.wrapping_add(1));
631 
632         let task = if tick % REMOTE_FIRST_INTERVAL == 0 {
633             self.context
634                 .shared
635                 .queue
636                 .lock()
637                 .as_mut()
638                 .and_then(|queue| queue.pop_front())
639                 .or_else(|| self.pop_local())
640         } else {
641             self.pop_local().or_else(|| {
642                 self.context
643                     .shared
644                     .queue
645                     .lock()
646                     .as_mut()
647                     .and_then(|queue| queue.pop_front())
648             })
649         };
650 
651         task.map(|task| unsafe {
652             // Safety: because the `LocalSet` itself is `!Send`, we know we are
653             // on the same thread if we have access to the `LocalSet`, and can
654             // therefore access the local run queue.
655             self.context.shared.local_state.assert_owner(task)
656         })
657     }
658 
pop_local(&self) -> Option<task::Notified<Arc<Shared>>>659     fn pop_local(&self) -> Option<task::Notified<Arc<Shared>>> {
660         unsafe {
661             // Safety: because the `LocalSet` itself is `!Send`, we know we are
662             // on the same thread if we have access to the `LocalSet`, and can
663             // therefore access the local run queue.
664             self.context.shared.local_state.task_pop_front()
665         }
666     }
667 
with<T>(&self, f: impl FnOnce() -> T) -> T668     fn with<T>(&self, f: impl FnOnce() -> T) -> T {
669         CURRENT.with(|LocalData { ctx, .. }| {
670             struct Reset<'a> {
671                 ctx_ref: &'a RcCell<Context>,
672                 val: Option<Rc<Context>>,
673             }
674             impl<'a> Drop for Reset<'a> {
675                 fn drop(&mut self) {
676                     self.ctx_ref.set(self.val.take());
677                 }
678             }
679             let old = ctx.replace(Some(self.context.clone()));
680 
681             let _reset = Reset {
682                 ctx_ref: ctx,
683                 val: old,
684             };
685 
686             f()
687         })
688     }
689 
690     /// This method is like `with`, but it just calls `f` without setting the thread-local if that
691     /// fails.
with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T692     fn with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T {
693         let mut f = Some(f);
694 
695         let res = CURRENT.try_with(|LocalData { ctx, .. }| {
696             struct Reset<'a> {
697                 ctx_ref: &'a RcCell<Context>,
698                 val: Option<Rc<Context>>,
699             }
700             impl<'a> Drop for Reset<'a> {
701                 fn drop(&mut self) {
702                     self.ctx_ref.replace(self.val.take());
703                 }
704             }
705             let old = ctx.replace(Some(self.context.clone()));
706 
707             let _reset = Reset {
708                 ctx_ref: ctx,
709                 val: old,
710             };
711 
712             (f.take().unwrap())()
713         });
714 
715         match res {
716             Ok(res) => res,
717             Err(_access_error) => (f.take().unwrap())(),
718         }
719     }
720 }
721 
722 cfg_unstable! {
723     impl LocalSet {
724         /// Configure how the `LocalSet` responds to an unhandled panic on a
725         /// spawned task.
726         ///
727         /// By default, an unhandled panic (i.e. a panic not caught by
728         /// [`std::panic::catch_unwind`]) has no impact on the `LocalSet`'s
729         /// execution. The panic is error value is forwarded to the task's
730         /// [`JoinHandle`] and all other spawned tasks continue running.
731         ///
732         /// The `unhandled_panic` option enables configuring this behavior.
733         ///
734         /// * `UnhandledPanic::Ignore` is the default behavior. Panics on
735         ///   spawned tasks have no impact on the `LocalSet`'s execution.
736         /// * `UnhandledPanic::ShutdownRuntime` will force the `LocalSet` to
737         ///   shutdown immediately when a spawned task panics even if that
738         ///   task's `JoinHandle` has not been dropped. All other spawned tasks
739         ///   will immediately terminate and further calls to
740         ///   [`LocalSet::block_on`] and [`LocalSet::run_until`] will panic.
741         ///
742         /// # Panics
743         ///
744         /// This method panics if called after the `LocalSet` has started
745         /// running.
746         ///
747         /// # Unstable
748         ///
749         /// This option is currently unstable and its implementation is
750         /// incomplete. The API may change or be removed in the future. See
751         /// tokio-rs/tokio#4516 for more details.
752         ///
753         /// # Examples
754         ///
755         /// The following demonstrates a `LocalSet` configured to shutdown on
756         /// panic. The first spawned task panics and results in the `LocalSet`
757         /// shutting down. The second spawned task never has a chance to
758         /// execute. The call to `run_until` will panic due to the runtime being
759         /// forcibly shutdown.
760         ///
761         /// ```should_panic
762         /// use tokio::runtime::UnhandledPanic;
763         ///
764         /// # #[tokio::main]
765         /// # async fn main() {
766         /// tokio::task::LocalSet::new()
767         ///     .unhandled_panic(UnhandledPanic::ShutdownRuntime)
768         ///     .run_until(async {
769         ///         tokio::task::spawn_local(async { panic!("boom"); });
770         ///         tokio::task::spawn_local(async {
771         ///             // This task never completes
772         ///         });
773         ///
774         ///         // Do some work, but `run_until` will panic before it completes
775         /// # loop { tokio::task::yield_now().await; }
776         ///     })
777         ///     .await;
778         /// # }
779         /// ```
780         ///
781         /// [`JoinHandle`]: struct@crate::task::JoinHandle
782         pub fn unhandled_panic(&mut self, behavior: crate::runtime::UnhandledPanic) -> &mut Self {
783             // TODO: This should be set as a builder
784             Rc::get_mut(&mut self.context)
785                 .and_then(|ctx| Arc::get_mut(&mut ctx.shared))
786                 .expect("Unhandled Panic behavior modified after starting LocalSet")
787                 .unhandled_panic = behavior;
788             self
789         }
790 
791         /// Returns the [`Id`] of the current `LocalSet` runtime.
792         ///
793         /// # Examples
794         ///
795         /// ```rust
796         /// use tokio::task;
797         ///
798         /// #[tokio::main]
799         /// async fn main() {
800         ///     let local_set = task::LocalSet::new();
801         ///     println!("Local set id: {}", local_set.id());
802         /// }
803         /// ```
804         ///
805         /// **Note**: This is an [unstable API][unstable]. The public API of this type
806         /// may break in 1.x releases. See [the documentation on unstable
807         /// features][unstable] for details.
808         ///
809         /// [unstable]: crate#unstable-features
810         /// [`Id`]: struct@crate::runtime::Id
811         pub fn id(&self) -> runtime::Id {
812             self.context.shared.local_state.owned.id.into()
813         }
814     }
815 }
816 
817 impl fmt::Debug for LocalSet {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result818     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
819         fmt.debug_struct("LocalSet").finish()
820     }
821 }
822 
823 impl Future for LocalSet {
824     type Output = ();
825 
poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output>826     fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
827         // Register the waker before starting to work
828         self.context.shared.waker.register_by_ref(cx.waker());
829 
830         if self.with(|| self.tick()) {
831             // If `tick` returns true, we need to notify the local future again:
832             // there are still tasks remaining in the run queue.
833             cx.waker().wake_by_ref();
834             Poll::Pending
835 
836         // Safety: called from the thread that owns `LocalSet`. Because
837         // `LocalSet` is `!Send`, this is safe.
838         } else if unsafe { self.context.shared.local_state.owned_is_empty() } {
839             // If the scheduler has no remaining futures, we're done!
840             Poll::Ready(())
841         } else {
842             // There are still futures in the local set, but we've polled all the
843             // futures in the run queue. Therefore, we can just return Pending
844             // since the remaining futures will be woken from somewhere else.
845             Poll::Pending
846         }
847     }
848 }
849 
850 impl Default for LocalSet {
default() -> LocalSet851     fn default() -> LocalSet {
852         LocalSet::new()
853     }
854 }
855 
impl Drop for LocalSet {
    /// Tears down the `LocalSet`: shuts down every spawned task, drops all
    /// pending notifications from both the local and the remote run queues,
    /// and finally asserts that no owned tasks remain.
    fn drop(&mut self) {
        // NOTE(review): `with_if_possible` is defined elsewhere in this file;
        // presumably it runs the closure inside this set's context when the
        // thread-local state is still available — confirm.
        self.with_if_possible(|| {
            // Shut down all tasks in the LocalOwnedTasks and close it to
            // prevent new tasks from ever being added.
            unsafe {
                // Safety: called from the thread that owns `LocalSet`
                self.context.shared.local_state.close_and_shutdown_all();
            }

            // Every task was already shut down above, so the notifications
            // still sitting in the queues below only need to be dropped, not
            // shut down again.

            // Safety: `take_local_queue` runs `assert_called_from_owner_thread`,
            // and that debug assertion is skipped when the current thread ID
            // cannot be obtained. This matters during `Drop`: on some systems
            // (e.g. at least some macOS versions), the thread-local data that
            // stores the thread ID can be destroyed *before* the `LocalSet`, so
            // looking up the current thread ID can fail here.
            //
            // Skipping the check in that case is still sound. `LocalSet` is
            // `!Send`, so we can reasonably guarantee that it will not be
            // `Drop`ped from another thread.
            let local_queue = unsafe {
                // Safety: called from the thread that owns `LocalSet`
                self.context.shared.local_state.take_local_queue()
            };
            for task in local_queue {
                drop(task);
            }

            // Take the queue from the Shared object to prevent pushing
            // notifications to it in the future. `Shared::schedule` treats a
            // `None` remote queue as "the LocalSet is dropped" and discards the
            // task. The `unwrap` assumes nothing else took the queue before
            // `Drop` ran.
            let queue = self.context.shared.queue.lock().take().unwrap();
            for task in queue {
                drop(task);
            }

            // Safety: called from the thread that owns `LocalSet`
            assert!(unsafe { self.context.shared.local_state.owned_is_empty() });
        });
    }
}
901 
902 // === impl Context ===
903 
904 impl Context {
905     #[track_caller]
spawn<F>(&self, future: F, name: Option<&str>) -> JoinHandle<F::Output> where F: Future + 'static, F::Output: 'static,906     fn spawn<F>(&self, future: F, name: Option<&str>) -> JoinHandle<F::Output>
907     where
908         F: Future + 'static,
909         F::Output: 'static,
910     {
911         let id = crate::runtime::task::Id::next();
912         let future = crate::util::trace::task(future, "local", name, id.as_u64());
913 
914         // Safety: called from the thread that owns the `LocalSet`
915         let (handle, notified) = {
916             self.shared.local_state.assert_called_from_owner_thread();
917             self.shared
918                 .local_state
919                 .owned
920                 .bind(future, self.shared.clone(), id)
921         };
922 
923         if let Some(notified) = notified {
924             self.shared.schedule(notified);
925         }
926 
927         handle
928     }
929 }
930 
// === impl RunUntil ===
932 
933 impl<T: Future> Future for RunUntil<'_, T> {
934     type Output = T::Output;
935 
poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output>936     fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
937         let me = self.project();
938 
939         me.local_set.with(|| {
940             me.local_set
941                 .context
942                 .shared
943                 .waker
944                 .register_by_ref(cx.waker());
945 
946             let _no_blocking = crate::runtime::context::disallow_block_in_place();
947             let f = me.future;
948 
949             if let Poll::Ready(output) = f.poll(cx) {
950                 return Poll::Ready(output);
951             }
952 
953             if me.local_set.tick() {
954                 // If `tick` returns `true`, we need to notify the local future again:
955                 // there are still tasks remaining in the run queue.
956                 cx.waker().wake_by_ref();
957             }
958 
959             Poll::Pending
960         })
961     }
962 }
963 
impl Shared {
    /// Schedule the provided task on the scheduler.
    ///
    /// The queue is chosen from the calling context, checked in this order:
    /// 1. This `LocalSet` is the one currently set in `CURRENT`: push onto the
    ///    local run queue without waking. Presumably the set is being polled
    ///    right now; compare the comment in arm 2.
    /// 2. The calling thread is the set's owner: push onto the local run queue
    ///    and wake the registered waker.
    /// 3. Otherwise: push onto the mutex-protected remote queue and wake. If
    ///    that queue has already been taken (the `LocalSet` was dropped), the
    ///    task is simply dropped.
    fn schedule(&self, task: task::Notified<Arc<Self>>) {
        CURRENT.with(|localdata| {
            match localdata.ctx.get() {
                Some(cx) if cx.shared.ptr_eq(self) => unsafe {
                    // Safety: if the current `LocalSet` context points to this
                    // `LocalSet`, then we are on the thread that owns it.
                    cx.shared.local_state.task_push_back(task);
                },

                // We are on the thread that owns the `LocalSet`, so we can
                // wake to the local queue.
                _ if context::thread_id().ok() == Some(self.local_state.owner) => {
                    unsafe {
                        // Safety: we just checked that the thread ID matches
                        // the localset's owner, so this is safe.
                        self.local_state.task_push_back(task);
                    }
                    // We still have to wake the `LocalSet`, because it isn't
                    // currently being polled.
                    self.waker.wake();
                }

                // We are *not* on the thread that owns the `LocalSet`, so we
                // have to wake to the remote queue.
                _ => {
                    // First, check whether the queue is still there (if not, the
                    // LocalSet is dropped). Then push to it if so, and if not,
                    // do nothing.
                    let mut lock = self.queue.lock();

                    if let Some(queue) = lock.as_mut() {
                        queue.push_back(task);
                        // Release the lock before waking so the woken
                        // `LocalSet` does not contend on it.
                        drop(lock);
                        self.waker.wake();
                    }
                }
            }
        });
    }

    /// Returns `true` if `self` and `other` are the same `Shared` instance,
    /// compared by address.
    fn ptr_eq(&self, other: &Shared) -> bool {
        std::ptr::eq(self, other)
    }
}
1010 
// SAFETY: `Shared` is used from multiple threads (it is held in an `Arc` and
// acts as the task scheduler). Its one thread-unsafe part, the local run queue
// inside `LocalState`, must never be touched except from the thread that owns
// the `LocalSet`. The type system does not enforce this. It relies on every
// caller of the `unsafe` `LocalState` accessors upholding that contract.
unsafe impl Sync for Shared {}
1014 
1015 impl task::Schedule for Arc<Shared> {
release(&self, task: &Task<Self>) -> Option<Task<Self>>1016     fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
1017         // Safety, this is always called from the thread that owns `LocalSet`
1018         unsafe { self.local_state.task_remove(task) }
1019     }
1020 
schedule(&self, task: task::Notified<Self>)1021     fn schedule(&self, task: task::Notified<Self>) {
1022         Shared::schedule(self, task);
1023     }
1024 
1025     cfg_unstable! {
1026         fn unhandled_panic(&self) {
1027             use crate::runtime::UnhandledPanic;
1028 
1029             match self.unhandled_panic {
1030                 UnhandledPanic::Ignore => {
1031                     // Do nothing
1032                 }
1033                 UnhandledPanic::ShutdownRuntime => {
1034                     // This hook is only called from within the runtime, so
1035                     // `CURRENT` should match with `&self`, i.e. there is no
1036                     // opportunity for a nested scheduler to be called.
1037                     CURRENT.with(|LocalData { ctx, .. }| match ctx.get() {
1038                         Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
1039                             cx.unhandled_panic.set(true);
1040                             // Safety: this is always called from the thread that owns `LocalSet`
1041                             unsafe { cx.shared.local_state.close_and_shutdown_all(); }
1042                         }
1043                         _ => unreachable!("runtime core not set in CURRENT thread-local"),
1044                     })
1045                 }
1046             }
1047         }
1048     }
1049 }
1050 
impl LocalState {
    /// Pops the next notified task from the local run queue, if any.
    ///
    /// # Safety
    ///
    /// Must be called from the thread that owns the `LocalSet`.
    unsafe fn task_pop_front(&self) -> Option<task::Notified<Arc<Shared>>> {
        // The caller ensures it is called from the same thread that owns
        // the LocalSet.
        self.assert_called_from_owner_thread();

        self.local_queue.with_mut(|ptr| (*ptr).pop_front())
    }

    /// Pushes a notified task onto the back of the local run queue.
    ///
    /// # Safety
    ///
    /// Must be called from the thread that owns the `LocalSet`.
    unsafe fn task_push_back(&self, task: task::Notified<Arc<Shared>>) {
        // The caller ensures it is called from the same thread that owns
        // the LocalSet.
        self.assert_called_from_owner_thread();

        self.local_queue.with_mut(|ptr| (*ptr).push_back(task))
    }

    /// Takes the whole local run queue, leaving an empty one in its place.
    ///
    /// # Safety
    ///
    /// Must be called from the thread that owns the `LocalSet`.
    unsafe fn take_local_queue(&self) -> VecDeque<task::Notified<Arc<Shared>>> {
        // The caller ensures it is called from the same thread that owns
        // the LocalSet.
        self.assert_called_from_owner_thread();

        self.local_queue.with_mut(|ptr| std::mem::take(&mut (*ptr)))
    }

    /// Removes `task` from the owned-task list, returning it if present.
    ///
    /// # Safety
    ///
    /// Must be called from the thread that owns the `LocalSet`.
    unsafe fn task_remove(&self, task: &Task<Arc<Shared>>) -> Option<Task<Arc<Shared>>> {
        // The caller ensures it is called from the same thread that owns
        // the LocalSet.
        self.assert_called_from_owner_thread();

        self.owned.remove(task)
    }

    /// Returns true if the `LocalSet` does not have any spawned tasks
    ///
    /// # Safety
    ///
    /// Must be called from the thread that owns the `LocalSet`.
    unsafe fn owned_is_empty(&self) -> bool {
        // The caller ensures it is called from the same thread that owns
        // the LocalSet.
        self.assert_called_from_owner_thread();

        self.owned.is_empty()
    }

    /// Converts a `Notified` into a `LocalNotified` via the owned-task list's
    /// `assert_owner`.
    ///
    /// # Safety
    ///
    /// Must be called from the thread that owns the `LocalSet`.
    unsafe fn assert_owner(
        &self,
        task: task::Notified<Arc<Shared>>,
    ) -> task::LocalNotified<Arc<Shared>> {
        // The caller ensures it is called from the same thread that owns
        // the LocalSet.
        self.assert_called_from_owner_thread();

        self.owned.assert_owner(task)
    }

    /// Closes the owned-task list and shuts down every task in it.
    ///
    /// # Safety
    ///
    /// Must be called from the thread that owns the `LocalSet`.
    unsafe fn close_and_shutdown_all(&self) {
        // The caller ensures it is called from the same thread that owns
        // the LocalSet.
        self.assert_called_from_owner_thread();

        self.owned.close_and_shutdown_all()
    }

    /// Debug-asserts that the current thread is the `LocalSet`'s owner.
    ///
    /// The check passes if the current thread ID cannot be determined, and it
    /// is compiled out entirely on OpenBSD and FreeBSD.
    #[track_caller]
    fn assert_called_from_owner_thread(&self) {
        // FreeBSD (and, per the cfg below, OpenBSD) has some weirdness around
        // thread-local destruction.
        // TODO: remove this hack when thread id is cleaned up
        #[cfg(not(any(target_os = "openbsd", target_os = "freebsd")))]
        debug_assert!(
            // if we couldn't get the thread ID because we're dropping the local
            // data, skip the assertion --- the `Drop` impl is not going to be
            // called from another thread, because `LocalSet` is `!Send`
            context::thread_id()
                .map(|id| id == self.owner)
                .unwrap_or(true),
            "`LocalSet`'s local run queue must not be accessed by another thread!"
        );
    }
}
1128 
// SAFETY: `LocalState` must be `Send` because it is stored in `Shared`, which
// is shared across threads. Its thread-unsafe contents must only be accessed
// from the thread that owns the `LocalSet`. Callers are responsible for that,
// either through the `unsafe` accessors or after calling
// `assert_called_from_owner_thread`.
unsafe impl Send for LocalState {}
1132 
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;

    // Does a `LocalSet` running on a current-thread runtime...basically work?
    //
    // This duplicates a test in `tests/task_local_set.rs`, but because this is
    // a lib test, it will run under Miri, so this is necessary to catch stacked
    // borrows violations in the `LocalSet` implementation.
    #[test]
    fn local_current_thread_scheduler() {
        let f = async {
            LocalSet::new()
                .run_until(async {
                    spawn_local(async {}).await.unwrap();
                })
                .await;
        };
        crate::runtime::Builder::new_current_thread()
            .build()
            .expect("rt")
            .block_on(f)
    }

    // Tests that when a task on a `LocalSet` is woken by an io driver on the
    // same thread, the task is woken to the localset's local queue rather than
    // its remote queue.
    //
    // This test has to be defined in the `local.rs` file as a lib test, rather
    // than in `tests/`, because it makes assertions about the local set's
    // internal state.
    #[test]
    fn wakes_to_local_queue() {
        use super::*;
        use crate::sync::Notify;
        let rt = crate::runtime::Builder::new_current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let local = LocalSet::new();
            let notify = Arc::new(Notify::new());
            // Spawn a task that parks until `notify` fires.
            let task = local.spawn_local({
                let notify = notify.clone();
                async move {
                    notify.notified().await;
                }
            });
            let mut run_until = Box::pin(local.run_until(async move {
                task.await.unwrap();
            }));

            // poll the run until future once
            crate::future::poll_fn(|cx| {
                let _ = run_until.as_mut().poll(cx);
                Poll::Ready(())
            })
            .await;

            // Wake the parked task from the owning thread while the
            // `LocalSet` is not being polled. The task is expected to land in
            // the local run queue, not the remote one.
            notify.notify_one();
            // Safety: this test runs on the thread that owns `local`.
            let task = unsafe { local.context.shared.local_state.task_pop_front() };
            // TODO(eliza): it would be nice to be able to assert that this is
            // the local task.
            assert!(
                task.is_some(),
                "task should have been notified to the LocalSet's local queue"
            );
        })
    }
}
1202