• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Runs `!Send` futures on the current thread.
2 use crate::loom::cell::UnsafeCell;
3 use crate::loom::sync::{Arc, Mutex};
4 use crate::runtime::task::{self, JoinHandle, LocalOwnedTasks, Task};
5 use crate::runtime::{context, ThreadId};
6 use crate::sync::AtomicWaker;
7 use crate::util::RcCell;
8 
9 use std::cell::Cell;
10 use std::collections::VecDeque;
11 use std::fmt;
12 use std::future::Future;
13 use std::marker::PhantomData;
14 use std::pin::Pin;
15 use std::rc::Rc;
16 use std::task::Poll;
17 
18 use pin_project_lite::pin_project;
19 
cfg_rt! {
    /// A set of tasks which are executed on the same thread.
    ///
    /// In some cases, it is necessary to run one or more futures that do not
    /// implement [`Send`] and thus are unsafe to send between threads. In these
    /// cases, a [local task set] may be used to schedule one or more `!Send`
    /// futures to run together on the same thread.
    ///
    /// For example, the following code will not compile:
    ///
    /// ```rust,compile_fail
    /// use std::rc::Rc;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     // `Rc` does not implement `Send`, and thus may not be sent between
    ///     // threads safely.
    ///     let unsend_data = Rc::new("my unsend data...");
    ///
    ///     let unsend_data = unsend_data.clone();
    ///     // Because the `async` block here moves `unsend_data`, the future is `!Send`.
    ///     // Since `tokio::spawn` requires the spawned future to implement `Send`, this
    ///     // will not compile.
    ///     tokio::spawn(async move {
    ///         println!("{}", unsend_data);
    ///         // ...
    ///     }).await.unwrap();
    /// }
    /// ```
    ///
    /// # Use with `run_until`
    ///
    /// To spawn `!Send` futures, we can use a local task set to schedule them
    /// on the thread calling [`Runtime::block_on`]. When running inside of the
    /// local task set, we can use [`task::spawn_local`], which can spawn
    /// `!Send` futures. For example:
    ///
    /// ```rust
    /// use std::rc::Rc;
    /// use tokio::task;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let unsend_data = Rc::new("my unsend data...");
    ///
    ///     // Construct a local task set that can run `!Send` futures.
    ///     let local = task::LocalSet::new();
    ///
    ///     // Run the local task set.
    ///     local.run_until(async move {
    ///         let unsend_data = unsend_data.clone();
    ///         // `spawn_local` ensures that the future is spawned on the local
    ///         // task set.
    ///         task::spawn_local(async move {
    ///             println!("{}", unsend_data);
    ///             // ...
    ///         }).await.unwrap();
    ///     }).await;
    /// }
    /// ```
    /// **Note:** The `run_until` method can only be used in `#[tokio::main]`,
    /// `#[tokio::test]` or directly inside a call to [`Runtime::block_on`]. It
    /// cannot be used inside a task spawned with `tokio::spawn`.
    ///
    /// ## Awaiting a `LocalSet`
    ///
    /// Additionally, a `LocalSet` itself implements `Future`, completing when
    /// *all* tasks spawned on the `LocalSet` complete. This can be used to run
    /// several futures on a `LocalSet` and drive the whole set until they
    /// complete. For example,
    ///
    /// ```rust
    /// use tokio::{task, time};
    /// use std::rc::Rc;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let unsend_data = Rc::new("world");
    ///     let local = task::LocalSet::new();
    ///
    ///     let unsend_data2 = unsend_data.clone();
    ///     local.spawn_local(async move {
    ///         // ...
    ///         println!("hello {}", unsend_data2)
    ///     });
    ///
    ///     local.spawn_local(async move {
    ///         time::sleep(time::Duration::from_millis(100)).await;
    ///         println!("goodbye {}", unsend_data)
    ///     });
    ///
    ///     // ...
    ///
    ///     local.await;
    /// }
    /// ```
    /// **Note:** Awaiting a `LocalSet` can only be done inside
    /// `#[tokio::main]`, `#[tokio::test]` or directly inside a call to
    /// [`Runtime::block_on`]. It cannot be used inside a task spawned with
    /// `tokio::spawn`.
    ///
    /// ## Use inside `tokio::spawn`
    ///
    /// The two methods mentioned above cannot be used inside `tokio::spawn`, so
    /// to spawn `!Send` futures from inside `tokio::spawn`, we need to do
    /// something else. The solution is to create the `LocalSet` somewhere else,
    /// and communicate with it using an [`mpsc`] channel.
    ///
    /// The following example puts the `LocalSet` inside a new thread.
    /// ```
    /// use tokio::runtime::Builder;
    /// use tokio::sync::{mpsc, oneshot};
    /// use tokio::task::LocalSet;
    ///
    /// // This struct describes the task you want to spawn. Here we include
    /// // some simple examples. The oneshot channel allows sending a response
    /// // to the spawner.
    /// #[derive(Debug)]
    /// enum Task {
    ///     PrintNumber(u32),
    ///     AddOne(u32, oneshot::Sender<u32>),
    /// }
    ///
    /// #[derive(Clone)]
    /// struct LocalSpawner {
    ///    send: mpsc::UnboundedSender<Task>,
    /// }
    ///
    /// impl LocalSpawner {
    ///     pub fn new() -> Self {
    ///         let (send, mut recv) = mpsc::unbounded_channel();
    ///
    ///         let rt = Builder::new_current_thread()
    ///             .enable_all()
    ///             .build()
    ///             .unwrap();
    ///
    ///         std::thread::spawn(move || {
    ///             let local = LocalSet::new();
    ///
    ///             local.spawn_local(async move {
    ///                 while let Some(new_task) = recv.recv().await {
    ///                     tokio::task::spawn_local(run_task(new_task));
    ///                 }
    ///                 // If the while loop returns, then all the LocalSpawner
    ///                 // objects have been dropped.
    ///             });
    ///
    ///             // This will return once all senders are dropped and all
    ///             // spawned tasks have returned.
    ///             rt.block_on(local);
    ///         });
    ///
    ///         Self {
    ///             send,
    ///         }
    ///     }
    ///
    ///     pub fn spawn(&self, task: Task) {
    ///         self.send.send(task).expect("Thread with LocalSet has shut down.");
    ///     }
    /// }
    ///
    /// // This task may do !Send stuff. We use printing a number as an example,
    /// // but it could be anything.
    /// //
    /// // The Task struct is an enum to support spawning many different kinds
    /// // of operations.
    /// async fn run_task(task: Task) {
    ///     match task {
    ///         Task::PrintNumber(n) => {
    ///             println!("{}", n);
    ///         },
    ///         Task::AddOne(n, response) => {
    ///             // We ignore failures to send the response.
    ///             let _ = response.send(n + 1);
    ///         },
    ///     }
    /// }
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let spawner = LocalSpawner::new();
    ///
    ///     let (send, response) = oneshot::channel();
    ///     spawner.spawn(Task::AddOne(10, send));
    ///     let eleven = response.await.unwrap();
    ///     assert_eq!(eleven, 11);
    /// }
    /// ```
    ///
    /// [`Send`]: trait@std::marker::Send
    /// [local task set]: struct@LocalSet
    /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on
    /// [`task::spawn_local`]: fn@spawn_local
    /// [`mpsc`]: mod@crate::sync::mpsc
    pub struct LocalSet {
        /// Current scheduler tick.
        ///
        /// Advanced with wrapping arithmetic every time the scheduler fetches a
        /// task. Every `REMOTE_FIRST_INTERVAL` ticks, the remote queue is
        /// checked before the local one.
        tick: Cell<u8>,

        /// State available from thread-local.
        ///
        /// A clone of this `Rc` is installed in the `CURRENT` thread-local
        /// while the set is entered (`LocalSet::enter`) or being driven. This
        /// is how the free function `spawn_local` finds the set.
        context: Rc<Context>,

        /// This type should not be Send.
        ///
        /// Raw pointers are `!Send`. Several `unsafe` blocks rely on this:
        /// having access to a `LocalSet` implies being on its owning thread.
        _not_send: PhantomData<*const ()>,
    }
}
227 
/// State available from the thread-local.
///
/// Shared through an `Rc` between the owning `LocalSet` and the `CURRENT`
/// thread-local. Because `Rc` is `!Send`, this value never leaves the
/// thread that created it.
struct Context {
    /// State shared between threads (remote run queue, waker, owned tasks).
    shared: Arc<Shared>,

    /// True if a task panicked without being handled and the local set is
    /// configured to shutdown on unhandled panic.
    ///
    /// Checked at the start of every iteration of `LocalSet::tick`, which
    /// panics when it is set.
    unhandled_panic: Cell<bool>,
}
237 
/// LocalSet state shared between threads.
///
/// Held behind an `Arc`. The task types used by this set are parameterized
/// on `Arc<Shared>` (e.g. `task::Notified<Arc<Shared>>`).
struct Shared {
    /// # Safety
    ///
    /// This field must *only* be accessed from the thread that owns the
    /// `LocalSet` (i.e., `Thread::current().id() == owner`).
    local_state: LocalState,

    /// Remote run queue: tasks notified from any thread, guarded by a mutex.
    ///
    /// Created as `Some(..)` in `LocalSet::new` and drained by the scheduler.
    /// NOTE(review): `None` presumably marks the queue as closed during
    /// shutdown; confirm against the `Drop` impl.
    queue: Mutex<Option<VecDeque<task::Notified<Arc<Shared>>>>>,

    /// Wake the `LocalSet` task.
    ///
    /// Registered in `LocalSet::poll` and woken by `LocalSet::spawn_named`.
    waker: AtomicWaker,

    /// How to respond to unhandled task panics.
    ///
    /// Defaults to `UnhandledPanic::Ignore` (see `LocalSet::new`).
    #[cfg(tokio_unstable)]
    pub(crate) unhandled_panic: crate::runtime::UnhandledPanic,
}
256 
/// Tracks the `LocalSet` state that must only be accessed from the thread that
/// created the `LocalSet`.
struct LocalState {
    /// The `ThreadId` of the thread that owns the `LocalSet`, captured in
    /// `LocalSet::new`.
    owner: ThreadId,

    /// Local run queue sender and receiver.
    ///
    /// Stored in an `UnsafeCell` without any lock. Soundness relies on it
    /// only being accessed from the `owner` thread.
    local_queue: UnsafeCell<VecDeque<task::Notified<Arc<Shared>>>>,

    /// Collection of all active tasks spawned onto this executor.
    owned: LocalOwnedTasks<Arc<Shared>>,
}
269 
pin_project! {
    /// Future awaited by `LocalSet::run_until`.
    ///
    /// It pairs the caller's `future` with the `LocalSet` whose tasks are
    /// driven alongside it.
    #[derive(Debug)]
    struct RunUntil<'a, F> {
        // The set whose spawned tasks are driven while `future` runs.
        local_set: &'a LocalSet,
        // The caller-provided future; structurally pinned.
        #[pin]
        future: F,
    }
}
278 
// Per-thread slot holding the context of the `LocalSet` that is currently
// entered on this thread, if any. The free function `spawn_local` reads it.
// `LocalSet::enter`, `LocalSet::with` and `LocalSet::with_if_possible`
// temporarily replace it.
tokio_thread_local!(static CURRENT: LocalData = const { LocalData {
    ctx: RcCell::new(),
} });
282 
/// Contents of the `CURRENT` thread-local.
struct LocalData {
    /// Context of the `LocalSet` currently entered on this thread, if any.
    ctx: RcCell<Context>,
}
286 
287 cfg_rt! {
288     /// Spawns a `!Send` future on the current [`LocalSet`].
289     ///
290     /// The spawned future will run on the same thread that called `spawn_local`.
291     ///
292     /// The provided future will start running in the background immediately
293     /// when `spawn_local` is called, even if you don't await the returned
294     /// `JoinHandle`.
295     ///
296     /// # Panics
297     ///
298     /// This function panics if called outside of a [`LocalSet`].
299     ///
300     /// Note that if [`tokio::spawn`] is used from within a `LocalSet`, the
301     /// resulting new task will _not_ be inside the `LocalSet`, so you must use
302     /// use `spawn_local` if you want to stay within the `LocalSet`.
303     ///
304     /// # Examples
305     ///
306     /// ```rust
307     /// use std::rc::Rc;
308     /// use tokio::task;
309     ///
310     /// #[tokio::main]
311     /// async fn main() {
312     ///     let unsend_data = Rc::new("my unsend data...");
313     ///
314     ///     let local = task::LocalSet::new();
315     ///
316     ///     // Run the local task set.
317     ///     local.run_until(async move {
318     ///         let unsend_data = unsend_data.clone();
319     ///         task::spawn_local(async move {
320     ///             println!("{}", unsend_data);
321     ///             // ...
322     ///         }).await.unwrap();
323     ///     }).await;
324     /// }
325     /// ```
326     ///
327     /// [`LocalSet`]: struct@crate::task::LocalSet
328     /// [`tokio::spawn`]: fn@crate::task::spawn
329     #[track_caller]
330     pub fn spawn_local<F>(future: F) -> JoinHandle<F::Output>
331     where
332         F: Future + 'static,
333         F::Output: 'static,
334     {
335         spawn_local_inner(future, None)
336     }
337 
338 
339     #[track_caller]
340     pub(super) fn spawn_local_inner<F>(future: F, name: Option<&str>) -> JoinHandle<F::Output>
341     where F: Future + 'static,
342           F::Output: 'static
343     {
344         match CURRENT.with(|LocalData { ctx, .. }| ctx.get()) {
345             None => panic!("`spawn_local` called from outside of a `task::LocalSet`"),
346             Some(cx) => cx.spawn(future, name)
347        }
348     }
349 }
350 
/// Initial queue capacity (used for both the local and the remote run queue).
const INITIAL_CAPACITY: usize = 64;

/// Max number of tasks to poll per tick.
///
/// When this many tasks have been run, `LocalSet::tick` returns `true`. The
/// `LocalSet` future then wakes itself and returns `Pending`, yielding back
/// to the runtime.
const MAX_TASKS_PER_TICK: usize = 61;

/// How often to check the remote queue first.
///
/// On every `REMOTE_FIRST_INTERVAL`-th scheduler tick, the remote
/// (cross-thread) queue is popped before the local one.
const REMOTE_FIRST_INTERVAL: u8 = 31;
359 
360 /// Context guard for LocalSet
361 pub struct LocalEnterGuard(Option<Rc<Context>>);
362 
363 impl Drop for LocalEnterGuard {
drop(&mut self)364     fn drop(&mut self) {
365         CURRENT.with(|LocalData { ctx, .. }| {
366             ctx.set(self.0.take());
367         })
368     }
369 }
370 
371 impl fmt::Debug for LocalEnterGuard {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result372     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
373         f.debug_struct("LocalEnterGuard").finish()
374     }
375 }
376 
377 impl LocalSet {
378     /// Returns a new local task set.
new() -> LocalSet379     pub fn new() -> LocalSet {
380         let owner = context::thread_id().expect("cannot create LocalSet during thread shutdown");
381 
382         LocalSet {
383             tick: Cell::new(0),
384             context: Rc::new(Context {
385                 shared: Arc::new(Shared {
386                     local_state: LocalState {
387                         owner,
388                         owned: LocalOwnedTasks::new(),
389                         local_queue: UnsafeCell::new(VecDeque::with_capacity(INITIAL_CAPACITY)),
390                     },
391                     queue: Mutex::new(Some(VecDeque::with_capacity(INITIAL_CAPACITY))),
392                     waker: AtomicWaker::new(),
393                     #[cfg(tokio_unstable)]
394                     unhandled_panic: crate::runtime::UnhandledPanic::Ignore,
395                 }),
396                 unhandled_panic: Cell::new(false),
397             }),
398             _not_send: PhantomData,
399         }
400     }
401 
402     /// Enters the context of this `LocalSet`.
403     ///
404     /// The [`spawn_local`] method will spawn tasks on the `LocalSet` whose
405     /// context you are inside.
406     ///
407     /// [`spawn_local`]: fn@crate::task::spawn_local
enter(&self) -> LocalEnterGuard408     pub fn enter(&self) -> LocalEnterGuard {
409         CURRENT.with(|LocalData { ctx, .. }| {
410             let old = ctx.replace(Some(self.context.clone()));
411             LocalEnterGuard(old)
412         })
413     }
414 
415     /// Spawns a `!Send` task onto the local task set.
416     ///
417     /// This task is guaranteed to be run on the current thread.
418     ///
419     /// Unlike the free function [`spawn_local`], this method may be used to
420     /// spawn local tasks when the `LocalSet` is _not_ running. The provided
421     /// future will start running once the `LocalSet` is next started, even if
422     /// you don't await the returned `JoinHandle`.
423     ///
424     /// # Examples
425     ///
426     /// ```rust
427     /// use tokio::task;
428     ///
429     /// #[tokio::main]
430     /// async fn main() {
431     ///     let local = task::LocalSet::new();
432     ///
433     ///     // Spawn a future on the local set. This future will be run when
434     ///     // we call `run_until` to drive the task set.
435     ///     local.spawn_local(async {
436     ///        // ...
437     ///     });
438     ///
439     ///     // Run the local task set.
440     ///     local.run_until(async move {
441     ///         // ...
442     ///     }).await;
443     ///
444     ///     // When `run` finishes, we can spawn _more_ futures, which will
445     ///     // run in subsequent calls to `run_until`.
446     ///     local.spawn_local(async {
447     ///        // ...
448     ///     });
449     ///
450     ///     local.run_until(async move {
451     ///         // ...
452     ///     }).await;
453     /// }
454     /// ```
455     /// [`spawn_local`]: fn@spawn_local
456     #[track_caller]
spawn_local<F>(&self, future: F) -> JoinHandle<F::Output> where F: Future + 'static, F::Output: 'static,457     pub fn spawn_local<F>(&self, future: F) -> JoinHandle<F::Output>
458     where
459         F: Future + 'static,
460         F::Output: 'static,
461     {
462         self.spawn_named(future, None)
463     }
464 
465     /// Runs a future to completion on the provided runtime, driving any local
466     /// futures spawned on this task set on the current thread.
467     ///
468     /// This runs the given future on the runtime, blocking until it is
469     /// complete, and yielding its resolved result. Any tasks or timers which
470     /// the future spawns internally will be executed on the runtime. The future
471     /// may also call [`spawn_local`] to spawn_local additional local futures on the
472     /// current thread.
473     ///
474     /// This method should not be called from an asynchronous context.
475     ///
476     /// # Panics
477     ///
478     /// This function panics if the executor is at capacity, if the provided
479     /// future panics, or if called within an asynchronous execution context.
480     ///
481     /// # Notes
482     ///
483     /// Since this function internally calls [`Runtime::block_on`], and drives
484     /// futures in the local task set inside that call to `block_on`, the local
485     /// futures may not use [in-place blocking]. If a blocking call needs to be
486     /// issued from a local task, the [`spawn_blocking`] API may be used instead.
487     ///
488     /// For example, this will panic:
489     /// ```should_panic
490     /// use tokio::runtime::Runtime;
491     /// use tokio::task;
492     ///
493     /// let rt  = Runtime::new().unwrap();
494     /// let local = task::LocalSet::new();
495     /// local.block_on(&rt, async {
496     ///     let join = task::spawn_local(async {
497     ///         let blocking_result = task::block_in_place(|| {
498     ///             // ...
499     ///         });
500     ///         // ...
501     ///     });
502     ///     join.await.unwrap();
503     /// })
504     /// ```
505     /// This, however, will not panic:
506     /// ```
507     /// use tokio::runtime::Runtime;
508     /// use tokio::task;
509     ///
510     /// let rt  = Runtime::new().unwrap();
511     /// let local = task::LocalSet::new();
512     /// local.block_on(&rt, async {
513     ///     let join = task::spawn_local(async {
514     ///         let blocking_result = task::spawn_blocking(|| {
515     ///             // ...
516     ///         }).await;
517     ///         // ...
518     ///     });
519     ///     join.await.unwrap();
520     /// })
521     /// ```
522     ///
523     /// [`spawn_local`]: fn@spawn_local
524     /// [`Runtime::block_on`]: method@crate::runtime::Runtime::block_on
525     /// [in-place blocking]: fn@crate::task::block_in_place
526     /// [`spawn_blocking`]: fn@crate::task::spawn_blocking
527     #[track_caller]
528     #[cfg(feature = "rt")]
529     #[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
block_on<F>(&self, rt: &crate::runtime::Runtime, future: F) -> F::Output where F: Future,530     pub fn block_on<F>(&self, rt: &crate::runtime::Runtime, future: F) -> F::Output
531     where
532         F: Future,
533     {
534         rt.block_on(self.run_until(future))
535     }
536 
537     /// Runs a future to completion on the local set, returning its output.
538     ///
539     /// This returns a future that runs the given future with a local set,
540     /// allowing it to call [`spawn_local`] to spawn additional `!Send` futures.
541     /// Any local futures spawned on the local set will be driven in the
542     /// background until the future passed to `run_until` completes. When the future
543     /// passed to `run` finishes, any local futures which have not completed
544     /// will remain on the local set, and will be driven on subsequent calls to
545     /// `run_until` or when [awaiting the local set] itself.
546     ///
547     /// # Examples
548     ///
549     /// ```rust
550     /// use tokio::task;
551     ///
552     /// #[tokio::main]
553     /// async fn main() {
554     ///     task::LocalSet::new().run_until(async {
555     ///         task::spawn_local(async move {
556     ///             // ...
557     ///         }).await.unwrap();
558     ///         // ...
559     ///     }).await;
560     /// }
561     /// ```
562     ///
563     /// [`spawn_local`]: fn@spawn_local
564     /// [awaiting the local set]: #awaiting-a-localset
run_until<F>(&self, future: F) -> F::Output where F: Future,565     pub async fn run_until<F>(&self, future: F) -> F::Output
566     where
567         F: Future,
568     {
569         let run_until = RunUntil {
570             future,
571             local_set: self,
572         };
573         run_until.await
574     }
575 
spawn_named<F>( &self, future: F, name: Option<&str>, ) -> JoinHandle<F::Output> where F: Future + 'static, F::Output: 'static,576     pub(in crate::task) fn spawn_named<F>(
577         &self,
578         future: F,
579         name: Option<&str>,
580     ) -> JoinHandle<F::Output>
581     where
582         F: Future + 'static,
583         F::Output: 'static,
584     {
585         let handle = self.context.spawn(future, name);
586 
587         // Because a task was spawned from *outside* the `LocalSet`, wake the
588         // `LocalSet` future to execute the new task, if it hasn't been woken.
589         //
590         // Spawning via the free fn `spawn` does not require this, as it can
591         // only be called from *within* a future executing on the `LocalSet` —
592         // in that case, the `LocalSet` must already be awake.
593         self.context.shared.waker.wake();
594         handle
595     }
596 
597     /// Ticks the scheduler, returning whether the local future needs to be
598     /// notified again.
tick(&self) -> bool599     fn tick(&self) -> bool {
600         for _ in 0..MAX_TASKS_PER_TICK {
601             // Make sure we didn't hit an unhandled panic
602             if self.context.unhandled_panic.get() {
603                 panic!("a spawned task panicked and the LocalSet is configured to shutdown on unhandled panic");
604             }
605 
606             match self.next_task() {
607                 // Run the task
608                 //
609                 // Safety: As spawned tasks are `!Send`, `run_unchecked` must be
610                 // used. We are responsible for maintaining the invariant that
611                 // `run_unchecked` is only called on threads that spawned the
612                 // task initially. Because `LocalSet` itself is `!Send`, and
613                 // `spawn_local` spawns into the `LocalSet` on the current
614                 // thread, the invariant is maintained.
615                 Some(task) => crate::runtime::coop::budget(|| task.run()),
616                 // We have fully drained the queue of notified tasks, so the
617                 // local future doesn't need to be notified again — it can wait
618                 // until something else wakes a task in the local set.
619                 None => return false,
620             }
621         }
622 
623         true
624     }
625 
next_task(&self) -> Option<task::LocalNotified<Arc<Shared>>>626     fn next_task(&self) -> Option<task::LocalNotified<Arc<Shared>>> {
627         let tick = self.tick.get();
628         self.tick.set(tick.wrapping_add(1));
629 
630         let task = if tick % REMOTE_FIRST_INTERVAL == 0 {
631             self.context
632                 .shared
633                 .queue
634                 .lock()
635                 .as_mut()
636                 .and_then(|queue| queue.pop_front())
637                 .or_else(|| self.pop_local())
638         } else {
639             self.pop_local().or_else(|| {
640                 self.context
641                     .shared
642                     .queue
643                     .lock()
644                     .as_mut()
645                     .and_then(|queue| queue.pop_front())
646             })
647         };
648 
649         task.map(|task| unsafe {
650             // Safety: because the `LocalSet` itself is `!Send`, we know we are
651             // on the same thread if we have access to the `LocalSet`, and can
652             // therefore access the local run queue.
653             self.context.shared.local_state.assert_owner(task)
654         })
655     }
656 
pop_local(&self) -> Option<task::Notified<Arc<Shared>>>657     fn pop_local(&self) -> Option<task::Notified<Arc<Shared>>> {
658         unsafe {
659             // Safety: because the `LocalSet` itself is `!Send`, we know we are
660             // on the same thread if we have access to the `LocalSet`, and can
661             // therefore access the local run queue.
662             self.context.shared.local_state.task_pop_front()
663         }
664     }
665 
with<T>(&self, f: impl FnOnce() -> T) -> T666     fn with<T>(&self, f: impl FnOnce() -> T) -> T {
667         CURRENT.with(|LocalData { ctx, .. }| {
668             struct Reset<'a> {
669                 ctx_ref: &'a RcCell<Context>,
670                 val: Option<Rc<Context>>,
671             }
672             impl<'a> Drop for Reset<'a> {
673                 fn drop(&mut self) {
674                     self.ctx_ref.set(self.val.take());
675                 }
676             }
677             let old = ctx.replace(Some(self.context.clone()));
678 
679             let _reset = Reset {
680                 ctx_ref: ctx,
681                 val: old,
682             };
683 
684             f()
685         })
686     }
687 
688     /// This method is like `with`, but it just calls `f` without setting the thread-local if that
689     /// fails.
with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T690     fn with_if_possible<T>(&self, f: impl FnOnce() -> T) -> T {
691         let mut f = Some(f);
692 
693         let res = CURRENT.try_with(|LocalData { ctx, .. }| {
694             struct Reset<'a> {
695                 ctx_ref: &'a RcCell<Context>,
696                 val: Option<Rc<Context>>,
697             }
698             impl<'a> Drop for Reset<'a> {
699                 fn drop(&mut self) {
700                     self.ctx_ref.replace(self.val.take());
701                 }
702             }
703             let old = ctx.replace(Some(self.context.clone()));
704 
705             let _reset = Reset {
706                 ctx_ref: ctx,
707                 val: old,
708             };
709 
710             (f.take().unwrap())()
711         });
712 
713         match res {
714             Ok(res) => res,
715             Err(_access_error) => (f.take().unwrap())(),
716         }
717     }
718 }
719 
720 cfg_unstable! {
721     impl LocalSet {
722         /// Configure how the `LocalSet` responds to an unhandled panic on a
723         /// spawned task.
724         ///
725         /// By default, an unhandled panic (i.e. a panic not caught by
726         /// [`std::panic::catch_unwind`]) has no impact on the `LocalSet`'s
727         /// execution. The panic is error value is forwarded to the task's
728         /// [`JoinHandle`] and all other spawned tasks continue running.
729         ///
730         /// The `unhandled_panic` option enables configuring this behavior.
731         ///
732         /// * `UnhandledPanic::Ignore` is the default behavior. Panics on
733         ///   spawned tasks have no impact on the `LocalSet`'s execution.
734         /// * `UnhandledPanic::ShutdownRuntime` will force the `LocalSet` to
735         ///   shutdown immediately when a spawned task panics even if that
736         ///   task's `JoinHandle` has not been dropped. All other spawned tasks
737         ///   will immediately terminate and further calls to
738         ///   [`LocalSet::block_on`] and [`LocalSet::run_until`] will panic.
739         ///
740         /// # Panics
741         ///
742         /// This method panics if called after the `LocalSet` has started
743         /// running.
744         ///
745         /// # Unstable
746         ///
747         /// This option is currently unstable and its implementation is
748         /// incomplete. The API may change or be removed in the future. See
749         /// tokio-rs/tokio#4516 for more details.
750         ///
751         /// # Examples
752         ///
753         /// The following demonstrates a `LocalSet` configured to shutdown on
754         /// panic. The first spawned task panics and results in the `LocalSet`
755         /// shutting down. The second spawned task never has a chance to
756         /// execute. The call to `run_until` will panic due to the runtime being
757         /// forcibly shutdown.
758         ///
759         /// ```should_panic
760         /// use tokio::runtime::UnhandledPanic;
761         ///
762         /// # #[tokio::main]
763         /// # async fn main() {
764         /// tokio::task::LocalSet::new()
765         ///     .unhandled_panic(UnhandledPanic::ShutdownRuntime)
766         ///     .run_until(async {
767         ///         tokio::task::spawn_local(async { panic!("boom"); });
768         ///         tokio::task::spawn_local(async {
769         ///             // This task never completes
770         ///         });
771         ///
772         ///         // Do some work, but `run_until` will panic before it completes
773         /// # loop { tokio::task::yield_now().await; }
774         ///     })
775         ///     .await;
776         /// # }
777         /// ```
778         ///
779         /// [`JoinHandle`]: struct@crate::task::JoinHandle
780         pub fn unhandled_panic(&mut self, behavior: crate::runtime::UnhandledPanic) -> &mut Self {
781             // TODO: This should be set as a builder
782             Rc::get_mut(&mut self.context)
783                 .and_then(|ctx| Arc::get_mut(&mut ctx.shared))
784                 .expect("Unhandled Panic behavior modified after starting LocalSet")
785                 .unhandled_panic = behavior;
786             self
787         }
788     }
789 }
790 
791 impl fmt::Debug for LocalSet {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result792     fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
793         fmt.debug_struct("LocalSet").finish()
794     }
795 }
796 
797 impl Future for LocalSet {
798     type Output = ();
799 
poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output>800     fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
801         // Register the waker before starting to work
802         self.context.shared.waker.register_by_ref(cx.waker());
803 
804         if self.with(|| self.tick()) {
805             // If `tick` returns true, we need to notify the local future again:
806             // there are still tasks remaining in the run queue.
807             cx.waker().wake_by_ref();
808             Poll::Pending
809 
810         // Safety: called from the thread that owns `LocalSet`. Because
811         // `LocalSet` is `!Send`, this is safe.
812         } else if unsafe { self.context.shared.local_state.owned_is_empty() } {
813             // If the scheduler has no remaining futures, we're done!
814             Poll::Ready(())
815         } else {
816             // There are still futures in the local set, but we've polled all the
817             // futures in the run queue. Therefore, we can just return Pending
818             // since the remaining futures will be woken from somewhere else.
819             Poll::Pending
820         }
821     }
822 }
823 
824 impl Default for LocalSet {
default() -> LocalSet825     fn default() -> LocalSet {
826         LocalSet::new()
827     }
828 }
829 
830 impl Drop for LocalSet {
drop(&mut self)831     fn drop(&mut self) {
832         self.with_if_possible(|| {
833             // Shut down all tasks in the LocalOwnedTasks and close it to
834             // prevent new tasks from ever being added.
835             unsafe {
836                 // Safety: called from the thread that owns `LocalSet`
837                 self.context.shared.local_state.close_and_shutdown_all();
838             }
839 
840             // We already called shutdown on all tasks above, so there is no
841             // need to call shutdown.
842 
843             // Safety: note that this *intentionally* bypasses the unsafe
844             // `Shared::local_queue()` method. This is in order to avoid the
845             // debug assertion that we are on the thread that owns the
846             // `LocalSet`, because on some systems (e.g. at least some macOS
847             // versions), attempting to get the current thread ID can panic due
848             // to the thread's local data that stores the thread ID being
849             // dropped *before* the `LocalSet`.
850             //
851             // Despite avoiding the assertion here, it is safe for us to access
852             // the local queue in `Drop`, because the `LocalSet` itself is
853             // `!Send`, so we can reasonably guarantee that it will not be
854             // `Drop`ped from another thread.
855             let local_queue = unsafe {
856                 // Safety: called from the thread that owns `LocalSet`
857                 self.context.shared.local_state.take_local_queue()
858             };
859             for task in local_queue {
860                 drop(task);
861             }
862 
863             // Take the queue from the Shared object to prevent pushing
864             // notifications to it in the future.
865             let queue = self.context.shared.queue.lock().take().unwrap();
866             for task in queue {
867                 drop(task);
868             }
869 
870             // Safety: called from the thread that owns `LocalSet`
871             assert!(unsafe { self.context.shared.local_state.owned_is_empty() });
872         });
873     }
874 }
875 
876 // === impl Context ===
877 
878 impl Context {
879     #[track_caller]
spawn<F>(&self, future: F, name: Option<&str>) -> JoinHandle<F::Output> where F: Future + 'static, F::Output: 'static,880     fn spawn<F>(&self, future: F, name: Option<&str>) -> JoinHandle<F::Output>
881     where
882         F: Future + 'static,
883         F::Output: 'static,
884     {
885         let id = crate::runtime::task::Id::next();
886         let future = crate::util::trace::task(future, "local", name, id.as_u64());
887 
888         // Safety: called from the thread that owns the `LocalSet`
889         let (handle, notified) = {
890             self.shared.local_state.assert_called_from_owner_thread();
891             self.shared
892                 .local_state
893                 .owned
894                 .bind(future, self.shared.clone(), id)
895         };
896 
897         if let Some(notified) = notified {
898             self.shared.schedule(notified);
899         }
900 
901         handle
902     }
903 }
904 
// === impl RunUntil ===
906 
907 impl<T: Future> Future for RunUntil<'_, T> {
908     type Output = T::Output;
909 
poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output>910     fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
911         let me = self.project();
912 
913         me.local_set.with(|| {
914             me.local_set
915                 .context
916                 .shared
917                 .waker
918                 .register_by_ref(cx.waker());
919 
920             let _no_blocking = crate::runtime::context::disallow_block_in_place();
921             let f = me.future;
922 
923             if let Poll::Ready(output) = f.poll(cx) {
924                 return Poll::Ready(output);
925             }
926 
927             if me.local_set.tick() {
928                 // If `tick` returns `true`, we need to notify the local future again:
929                 // there are still tasks remaining in the run queue.
930                 cx.waker().wake_by_ref();
931             }
932 
933             Poll::Pending
934         })
935     }
936 }
937 
938 impl Shared {
939     /// Schedule the provided task on the scheduler.
schedule(&self, task: task::Notified<Arc<Self>>)940     fn schedule(&self, task: task::Notified<Arc<Self>>) {
941         CURRENT.with(|localdata| {
942             match localdata.ctx.get() {
943                 Some(cx) if cx.shared.ptr_eq(self) => unsafe {
944                     // Safety: if the current `LocalSet` context points to this
945                     // `LocalSet`, then we are on the thread that owns it.
946                     cx.shared.local_state.task_push_back(task);
947                 },
948 
949                 // We are on the thread that owns the `LocalSet`, so we can
950                 // wake to the local queue.
951                 _ if context::thread_id().ok() == Some(self.local_state.owner) => {
952                     unsafe {
953                         // Safety: we just checked that the thread ID matches
954                         // the localset's owner, so this is safe.
955                         self.local_state.task_push_back(task);
956                     }
957                     // We still have to wake the `LocalSet`, because it isn't
958                     // currently being polled.
959                     self.waker.wake();
960                 }
961 
962                 // We are *not* on the thread that owns the `LocalSet`, so we
963                 // have to wake to the remote queue.
964                 _ => {
965                     // First, check whether the queue is still there (if not, the
966                     // LocalSet is dropped). Then push to it if so, and if not,
967                     // do nothing.
968                     let mut lock = self.queue.lock();
969 
970                     if let Some(queue) = lock.as_mut() {
971                         queue.push_back(task);
972                         drop(lock);
973                         self.waker.wake();
974                     }
975                 }
976             }
977         });
978     }
979 
ptr_eq(&self, other: &Shared) -> bool980     fn ptr_eq(&self, other: &Shared) -> bool {
981         std::ptr::eq(self, other)
982     }
983 }
984 
// SAFETY: `Shared` contains `LocalState`, whose local run queue and owned-task
// list are not synchronized. Sharing `Shared` across threads is sound only
// because that state is touched exclusively from the thread that owns the
// `LocalSet`: every accessor is an `unsafe fn` that debug-asserts the owner
// thread. Other threads reach `Shared` through `schedule`. On that path they
// only compare against `local_state.owner` and use the mutex-guarded `queue`
// and the `waker`.
unsafe impl Sync for Shared {}
988 
989 impl task::Schedule for Arc<Shared> {
release(&self, task: &Task<Self>) -> Option<Task<Self>>990     fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
991         // Safety, this is always called from the thread that owns `LocalSet`
992         unsafe { self.local_state.task_remove(task) }
993     }
994 
schedule(&self, task: task::Notified<Self>)995     fn schedule(&self, task: task::Notified<Self>) {
996         Shared::schedule(self, task);
997     }
998 
999     cfg_unstable! {
1000         fn unhandled_panic(&self) {
1001             use crate::runtime::UnhandledPanic;
1002 
1003             match self.unhandled_panic {
1004                 UnhandledPanic::Ignore => {
1005                     // Do nothing
1006                 }
1007                 UnhandledPanic::ShutdownRuntime => {
1008                     // This hook is only called from within the runtime, so
1009                     // `CURRENT` should match with `&self`, i.e. there is no
1010                     // opportunity for a nested scheduler to be called.
1011                     CURRENT.with(|LocalData { ctx, .. }| match ctx.get() {
1012                         Some(cx) if Arc::ptr_eq(self, &cx.shared) => {
1013                             cx.unhandled_panic.set(true);
1014                             // Safety: this is always called from the thread that owns `LocalSet`
1015                             unsafe { cx.shared.local_state.close_and_shutdown_all(); }
1016                         }
1017                         _ => unreachable!("runtime core not set in CURRENT thread-local"),
1018                     })
1019                 }
1020             }
1021         }
1022     }
1023 }
1024 
1025 impl LocalState {
task_pop_front(&self) -> Option<task::Notified<Arc<Shared>>>1026     unsafe fn task_pop_front(&self) -> Option<task::Notified<Arc<Shared>>> {
1027         // The caller ensures it is called from the same thread that owns
1028         // the LocalSet.
1029         self.assert_called_from_owner_thread();
1030 
1031         self.local_queue.with_mut(|ptr| (*ptr).pop_front())
1032     }
1033 
task_push_back(&self, task: task::Notified<Arc<Shared>>)1034     unsafe fn task_push_back(&self, task: task::Notified<Arc<Shared>>) {
1035         // The caller ensures it is called from the same thread that owns
1036         // the LocalSet.
1037         self.assert_called_from_owner_thread();
1038 
1039         self.local_queue.with_mut(|ptr| (*ptr).push_back(task))
1040     }
1041 
take_local_queue(&self) -> VecDeque<task::Notified<Arc<Shared>>>1042     unsafe fn take_local_queue(&self) -> VecDeque<task::Notified<Arc<Shared>>> {
1043         // The caller ensures it is called from the same thread that owns
1044         // the LocalSet.
1045         self.assert_called_from_owner_thread();
1046 
1047         self.local_queue.with_mut(|ptr| std::mem::take(&mut (*ptr)))
1048     }
1049 
task_remove(&self, task: &Task<Arc<Shared>>) -> Option<Task<Arc<Shared>>>1050     unsafe fn task_remove(&self, task: &Task<Arc<Shared>>) -> Option<Task<Arc<Shared>>> {
1051         // The caller ensures it is called from the same thread that owns
1052         // the LocalSet.
1053         self.assert_called_from_owner_thread();
1054 
1055         self.owned.remove(task)
1056     }
1057 
1058     /// Returns true if the `LocalSet` does not have any spawned tasks
owned_is_empty(&self) -> bool1059     unsafe fn owned_is_empty(&self) -> bool {
1060         // The caller ensures it is called from the same thread that owns
1061         // the LocalSet.
1062         self.assert_called_from_owner_thread();
1063 
1064         self.owned.is_empty()
1065     }
1066 
assert_owner( &self, task: task::Notified<Arc<Shared>>, ) -> task::LocalNotified<Arc<Shared>>1067     unsafe fn assert_owner(
1068         &self,
1069         task: task::Notified<Arc<Shared>>,
1070     ) -> task::LocalNotified<Arc<Shared>> {
1071         // The caller ensures it is called from the same thread that owns
1072         // the LocalSet.
1073         self.assert_called_from_owner_thread();
1074 
1075         self.owned.assert_owner(task)
1076     }
1077 
close_and_shutdown_all(&self)1078     unsafe fn close_and_shutdown_all(&self) {
1079         // The caller ensures it is called from the same thread that owns
1080         // the LocalSet.
1081         self.assert_called_from_owner_thread();
1082 
1083         self.owned.close_and_shutdown_all()
1084     }
1085 
1086     #[track_caller]
assert_called_from_owner_thread(&self)1087     fn assert_called_from_owner_thread(&self) {
1088         // FreeBSD has some weirdness around thread-local destruction.
1089         // TODO: remove this hack when thread id is cleaned up
1090         #[cfg(not(any(target_os = "openbsd", target_os = "freebsd")))]
1091         debug_assert!(
1092             // if we couldn't get the thread ID because we're dropping the local
1093             // data, skip the assertion --- the `Drop` impl is not going to be
1094             // called from another thread, because `LocalSet` is `!Send`
1095             context::thread_id()
1096                 .map(|id| id == self.owner)
1097                 .unwrap_or(true),
1098             "`LocalSet`'s local run queue must not be accessed by another thread!"
1099         );
1100     }
1101 }
1102 
// SAFETY: `LocalState` is `Send` only so that it can be stored in `Shared`,
// which is shared between threads through an `Arc`. Its unsynchronized
// contents (local run queue, owned-task list) must still only be touched from
// the thread that owns the `LocalSet`. The `unsafe fn` accessors in
// `impl LocalState` make this the caller's obligation, and each of them
// debug-asserts it via `assert_called_from_owner_thread`.
unsafe impl Send for LocalState {}
1106 
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;

    // Does a `LocalSet` running on a current-thread runtime...basically work?
    //
    // This duplicates a test in `tests/task_local_set.rs`, but because this is
    // a lib test, it will run under Miri, so this is necessary to catch stacked
    // borrows violations in the `LocalSet` implementation.
    #[test]
    fn local_current_thread_scheduler() {
        let f = async {
            LocalSet::new()
                .run_until(async {
                    spawn_local(async {}).await.unwrap();
                })
                .await;
        };
        crate::runtime::Builder::new_current_thread()
            .build()
            .expect("rt")
            .block_on(f)
    }

    // Tests that when a task on a `LocalSet` is woken by an io driver on the
    // same thread, the task is woken to the localset's local queue rather than
    // its remote queue.
    //
    // This test has to be defined in the `local.rs` file as a lib test, rather
    // than in `tests/`, because it makes assertions about the local set's
    // internal state.
    #[test]
    fn wakes_to_local_queue() {
        use crate::sync::Notify;

        let rt = crate::runtime::Builder::new_current_thread()
            .build()
            .expect("rt");
        rt.block_on(async {
            let local = LocalSet::new();
            let notify = Arc::new(Notify::new());
            let task = local.spawn_local({
                let notify = notify.clone();
                async move {
                    notify.notified().await;
                }
            });
            let mut run_until = Box::pin(local.run_until(async move {
                task.await.unwrap();
            }));

            // Poll the `run_until` future once, giving the spawned task a
            // chance to start waiting on `notify`.
            crate::future::poll_fn(|cx| {
                let _ = run_until.as_mut().poll(cx);
                Poll::Ready(())
            })
            .await;

            // Wake the task from the owner thread while the set is not being
            // polled; it should land on the local queue.
            notify.notify_one();
            // SAFETY: we are on the thread that owns `local`.
            let woken = unsafe { local.context.shared.local_state.task_pop_front() };
            // TODO(eliza): it would be nice to be able to assert that this is
            // the local task.
            assert!(
                woken.is_some(),
                "task should have been notified to the LocalSet's local queue"
            );
        })
    }
}
1176