use crate::loom::thread::AccessError;
use crate::runtime::coop;

use std::cell::Cell;

#[cfg(any(feature = "rt", feature = "macros"))]
use crate::util::rand::{FastRand, RngSeed};

cfg_rt! {
    use crate::runtime::{scheduler, task::Id, Defer};

    use std::cell::RefCell;
    use std::marker::PhantomData;
    use std::time::Duration;
}

struct Context {
    /// Uniquely identifies the current thread.
    #[cfg(feature = "rt")]
    thread_id: Cell<Option<ThreadId>>,

    /// Handle to the runtime scheduler running on the current thread.
    #[cfg(feature = "rt")]
    handle: RefCell<Option<scheduler::Handle>>,

    #[cfg(feature = "rt")]
    current_task_id: Cell<Option<Id>>,

    /// Tracks if the current thread is currently driving a runtime.
    /// Note that if this is set to "entered", the current scheduler
    /// handle may not reference the runtime currently executing. This
    /// is because other runtime handles may be set to current from
    /// within a runtime (see the sketch after this struct).
    #[cfg(feature = "rt")]
    runtime: Cell<EnterRuntime>,

    /// Yielded task wakers are stored here and notified after resource drivers
    /// are polled.
    #[cfg(feature = "rt")]
    defer: RefCell<Option<Defer>>,

    #[cfg(any(feature = "rt", feature = "macros"))]
    rng: FastRand,

    /// Tracks the amount of "work" a task may still do before yielding back to
    /// the scheduler.
    budget: Cell<coop::Budget>,
}
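
// Hedged illustration of the "entered vs. current handle" distinction noted on
// the `runtime` field (user-level code; the exact API calls are assumptions):
//
//     let rt_a = tokio::runtime::Runtime::new().unwrap();
//     let rt_b = tokio::runtime::Runtime::new().unwrap();
//     rt_a.block_on(async {
//         // The thread is still "entered" for rt_a, but entering rt_b's
//         // handle makes rt_b the *current* handle for spawning and drivers.
//         let _guard = rt_b.enter();
//     });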

tokio_thread_local! {
    static CONTEXT: Context = {
        Context {
            #[cfg(feature = "rt")]
            thread_id: Cell::new(None),

            // Tracks the current runtime handle to use when spawning,
            // accessing drivers, etc...
            #[cfg(feature = "rt")]
            handle: RefCell::new(None),
            #[cfg(feature = "rt")]
            current_task_id: Cell::new(None),

            // Tracks if the current thread is currently driving a runtime.
            // Note that if this is set to "entered", the current scheduler
            // handle may not reference the runtime currently executing. This
            // is because other runtime handles may be set to current from
            // within a runtime.
            #[cfg(feature = "rt")]
            runtime: Cell::new(EnterRuntime::NotEntered),

            #[cfg(feature = "rt")]
            defer: RefCell::new(None),

            #[cfg(any(feature = "rt", feature = "macros"))]
            rng: FastRand::new(RngSeed::new()),

            budget: Cell::new(coop::Budget::unconstrained()),
        }
    }
}

#[cfg(feature = "macros")]
pub(crate) fn thread_rng_n(n: u32) -> u32 {
    CONTEXT.with(|ctx| ctx.rng.fastrand_n(n))
}
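
// Hedged note: this is gated on the `macros` feature because macro expansions
// need a per-thread RNG even when the `rt` feature is disabled; for example,
// expanded `select!`-style code picks a random branch index to poll first,
// conceptually `let start = thread_rng_n(BRANCH_COUNT);`.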

pub(super) fn budget<R>(f: impl FnOnce(&Cell<coop::Budget>) -> R) -> Result<R, AccessError> {
    CONTEXT.try_with(|ctx| f(&ctx.budget))
}
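
// A minimal sketch of how this accessor is typically driven (hedged: the real
// call sites and the initial budget value live in `crate::runtime::coop`, and
// `Budget::initial()` here is an assumption about that module's API):
//
//     let prev = budget(|cell| cell.replace(coop::Budget::initial()))?;
//     // ... poll a task, charging its work against the budget ...
//     budget(|cell| cell.set(prev))?;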

cfg_rt! {
    use crate::runtime::{ThreadId, TryCurrentError};

    use std::fmt;

    pub(crate) fn thread_id() -> Result<ThreadId, AccessError> {
        CONTEXT.try_with(|ctx| {
            match ctx.thread_id.get() {
                Some(id) => id,
                None => {
                    let id = ThreadId::next();
                    ctx.thread_id.set(Some(id));
                    id
                }
            }
        })
    }

    #[derive(Debug, Clone, Copy)]
    #[must_use]
    pub(crate) enum EnterRuntime {
        /// Currently in a runtime context.
        #[cfg_attr(not(feature = "rt"), allow(dead_code))]
        Entered { allow_block_in_place: bool },

        /// Not in a runtime context **or** a blocking region.
        NotEntered,
    }

    #[derive(Debug)]
    #[must_use]
    pub(crate) struct SetCurrentGuard {
        old_handle: Option<scheduler::Handle>,
        old_seed: RngSeed,
    }

    /// Guard tracking that a caller has entered a runtime context.
    #[must_use]
    pub(crate) struct EnterRuntimeGuard {
        /// Tracks that the current thread has entered a blocking function call.
        pub(crate) blocking: BlockingRegionGuard,

        #[allow(dead_code)] // Only tracking the guard.
        pub(crate) handle: SetCurrentGuard,

        /// If true, then this is the root runtime guard. It is possible to nest
        /// runtime guards by using `block_in_place` between the calls. We need
        /// to track the root guard as this is the guard responsible for freeing
        /// the deferred task queue (see the sketch after this struct).
        is_root: bool,
    }
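
    // Hedged sketch of the nesting mentioned on `is_root` (user-level code,
    // assuming the multi-threaded runtime so `block_in_place` is available):
    //
    //     let rt = tokio::runtime::Runtime::new().unwrap();
    //     rt.block_on(async {                    // root guard, is_root == true
    //         tokio::task::block_in_place(|| {
    //             rt.block_on(async {});         // nested guard, is_root == false
    //         });
    //     });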

    /// Guard tracking that a caller has entered a blocking region.
    #[must_use]
    pub(crate) struct BlockingRegionGuard {
        _p: PhantomData<RefCell<()>>,
    }

    pub(crate) struct DisallowBlockInPlaceGuard(bool);

    pub(crate) fn set_current_task_id(id: Option<Id>) -> Option<Id> {
        CONTEXT.try_with(|ctx| ctx.current_task_id.replace(id)).unwrap_or(None)
    }

    pub(crate) fn current_task_id() -> Option<Id> {
        CONTEXT.try_with(|ctx| ctx.current_task_id.get()).unwrap_or(None)
    }

    pub(crate) fn try_current() -> Result<scheduler::Handle, TryCurrentError> {
        match CONTEXT.try_with(|ctx| ctx.handle.borrow().clone()) {
            Ok(Some(handle)) => Ok(handle),
            Ok(None) => Err(TryCurrentError::new_no_context()),
            Err(_access_error) => Err(TryCurrentError::new_thread_local_destroyed()),
        }
    }

    /// Sets this [`Handle`] as the current active [`Handle`].
    ///
    /// [`Handle`]: crate::runtime::scheduler::Handle
    pub(crate) fn try_set_current(handle: &scheduler::Handle) -> Option<SetCurrentGuard> {
        CONTEXT.try_with(|ctx| ctx.set_current(handle)).ok()
    }
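
    // Hedged sketch of how the guard is meant to be used (the concrete caller
    // is an assumption; something like the public `Handle::enter` is the
    // natural fit):
    //
    //     let guard = try_set_current(&handle).expect("thread-local destroyed");
    //     // While `guard` is alive, `try_current()` yields a clone of `handle`.
    //     drop(guard); // restores the previously current handle and RNG seed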

    /// Marks the current thread as being within the dynamic extent of an
    /// executor.
    #[track_caller]
    pub(crate) fn enter_runtime(
        handle: &scheduler::Handle,
        allow_block_in_place: bool,
    ) -> EnterRuntimeGuard {
        if let Some(enter) = try_enter_runtime(handle, allow_block_in_place) {
            return enter;
        }

        panic!(
            "Cannot start a runtime from within a runtime. This happens \
             because a function (like `block_on`) attempted to block the \
             current thread while the thread is being used to drive \
             asynchronous tasks."
        );
    }
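
    // Hedged, user-level illustration of the panic path above (hypothetical
    // code, not part of this module): entering a runtime while the thread is
    // already driving one.
    //
    //     let rt = tokio::runtime::Runtime::new().unwrap();
    //     rt.block_on(async {
    //         // Calling `rt.block_on(...)` again here would re-enter
    //         // `enter_runtime` and hit the panic above.
    //     });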

    /// Tries to enter a runtime context, returns `None` if already in a runtime
    /// context.
    fn try_enter_runtime(
        handle: &scheduler::Handle,
        allow_block_in_place: bool,
    ) -> Option<EnterRuntimeGuard> {
        CONTEXT.with(|c| {
            if c.runtime.get().is_entered() {
                None
            } else {
                // Set the entered flag
                c.runtime.set(EnterRuntime::Entered { allow_block_in_place });

                // Initialize queue to track yielded tasks
                let mut defer = c.defer.borrow_mut();

                let is_root = if defer.is_none() {
                    *defer = Some(Defer::new());
                    true
                } else {
                    false
                };

                Some(EnterRuntimeGuard {
                    blocking: BlockingRegionGuard::new(),
                    handle: c.set_current(handle),
                    is_root,
                })
            }
        })
    }

    pub(crate) fn try_enter_blocking_region() -> Option<BlockingRegionGuard> {
        CONTEXT
            .try_with(|c| {
                if c.runtime.get().is_entered() {
                    None
                } else {
                    Some(BlockingRegionGuard::new())
                }
                // If accessing the thread-local fails, the thread is terminating
                // and thread-locals are being destroyed. Because we don't know if
                // we are currently in a runtime or not, we default to being
                // permissive.
            })
            .unwrap_or_else(|_| Some(BlockingRegionGuard::new()))
    }

    /// Disallows blocking in the current runtime context until the guard is dropped.
    pub(crate) fn disallow_block_in_place() -> DisallowBlockInPlaceGuard {
        let reset = CONTEXT.with(|c| {
            if let EnterRuntime::Entered {
                allow_block_in_place: true,
            } = c.runtime.get()
            {
                c.runtime.set(EnterRuntime::Entered {
                    allow_block_in_place: false,
                });
                true
            } else {
                false
            }
        });

        DisallowBlockInPlaceGuard(reset)
    }
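
    // The flag's round trip, spelled out (this follows directly from the code
    // above and the guard's `Drop` impl below): the guard only restores
    // `allow_block_in_place` if this call was the one that cleared it, so
    // nested calls compose.
    //
    //     let outer = disallow_block_in_place(); // true -> false, reset = true
    //     let inner = disallow_block_in_place(); // already false, reset = false
    //     drop(inner);                           // no-op
    //     drop(outer);                           // allow_block_in_place = true again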

    pub(crate) fn with_defer<R>(f: impl FnOnce(&mut Defer) -> R) -> Option<R> {
        CONTEXT.with(|c| {
            let mut defer = c.defer.borrow_mut();
            defer.as_mut().map(f)
        })
    }

    impl Context {
        fn set_current(&self, handle: &scheduler::Handle) -> SetCurrentGuard {
            let rng_seed = handle.seed_generator().next_seed();

            let old_handle = self.handle.borrow_mut().replace(handle.clone());
            let old_seed = self.rng.replace_seed(rng_seed);

            SetCurrentGuard {
                old_handle,
                old_seed,
            }
        }
    }

    impl Drop for SetCurrentGuard {
        fn drop(&mut self) {
            CONTEXT.with(|ctx| {
                *ctx.handle.borrow_mut() = self.old_handle.take();
                ctx.rng.replace_seed(self.old_seed.clone());
            });
        }
    }

    impl fmt::Debug for EnterRuntimeGuard {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("Enter").finish()
        }
    }

    impl Drop for EnterRuntimeGuard {
        fn drop(&mut self) {
            CONTEXT.with(|c| {
                assert!(c.runtime.get().is_entered());
                c.runtime.set(EnterRuntime::NotEntered);

                if self.is_root {
                    *c.defer.borrow_mut() = None;
                }
            });
        }
    }

    impl BlockingRegionGuard {
        fn new() -> BlockingRegionGuard {
            BlockingRegionGuard { _p: PhantomData }
        }

        /// Blocks the thread on the specified future, returning the value with
        /// which that future completes.
        pub(crate) fn block_on<F>(&mut self, f: F) -> Result<F::Output, AccessError>
        where
            F: std::future::Future,
        {
            use crate::runtime::park::CachedParkThread;

            let mut park = CachedParkThread::new();
            park.block_on(f)
        }

        /// Blocks the thread on the specified future for **at most** `timeout`.
        ///
        /// If the future completes before `timeout`, the result is returned. If
        /// `timeout` elapses, then `Err` is returned.
        pub(crate) fn block_on_timeout<F>(
            &mut self,
            f: F,
            timeout: Duration,
        ) -> Result<F::Output, ()>
        where
            F: std::future::Future,
        {
            use crate::runtime::park::CachedParkThread;
            use std::task::Context;
            use std::task::Poll::Ready;
            use std::time::Instant;

            let mut park = CachedParkThread::new();
            let waker = park.waker().map_err(|_| ())?;
            let mut cx = Context::from_waker(&waker);

            pin!(f);
            let when = Instant::now() + timeout;

            loop {
                if let Ready(v) = crate::runtime::coop::budget(|| f.as_mut().poll(&mut cx)) {
                    return Ok(v);
                }

                let now = Instant::now();

                if now >= when {
                    return Err(());
                }

                // Wake any yielded tasks before parking in order to avoid
                // blocking.
                with_defer(|defer| defer.wake());

                park.park_timeout(when - now);
            }
        }
    }

    impl Drop for DisallowBlockInPlaceGuard {
        fn drop(&mut self) {
            if self.0 {
                // XXX: Do we want some kind of assertion here, or is "best effort" okay?
                CONTEXT.with(|c| {
                    if let EnterRuntime::Entered {
                        allow_block_in_place: false,
                    } = c.runtime.get()
                    {
                        c.runtime.set(EnterRuntime::Entered {
                            allow_block_in_place: true,
                        });
                    }
                })
            }
        }
    }

    impl EnterRuntime {
        pub(crate) fn is_entered(self) -> bool {
            matches!(self, EnterRuntime::Entered { .. })
        }
    }
}

cfg_rt_multi_thread! {
    /// Returns the current thread's `EnterRuntime` state.
    pub(crate) fn current_enter_context() -> EnterRuntime {
        CONTEXT.with(|c| c.runtime.get())
    }

    /// Forces the current "entered" state to be cleared while the closure
    /// is executed.
    ///
    /// # Warning
    ///
    /// This is hidden for a reason. Do not use without fully understanding
    /// executors. Misusing can easily cause your program to deadlock.
    pub(crate) fn exit_runtime<F: FnOnce() -> R, R>(f: F) -> R {
        // Reset in case the closure panics
        struct Reset(EnterRuntime);

        impl Drop for Reset {
            fn drop(&mut self) {
                CONTEXT.with(|c| {
                    assert!(!c.runtime.get().is_entered(), "closure claimed permanent executor");
                    c.runtime.set(self.0);
                });
            }
        }

        let was = CONTEXT.with(|c| {
            let e = c.runtime.get();
            assert!(e.is_entered(), "asked to exit when not entered");
            c.runtime.set(EnterRuntime::NotEntered);
            e
        });

        let _reset = Reset(was);
        // dropping _reset after f() will reset ENTERED
        f()
    }
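
    // Hedged illustration (the concrete caller is an assumption; conceptually
    // this is what a `block_in_place`-style helper needs):
    //
    //     exit_runtime(|| {
    //         // The thread is no longer marked "entered" here, so operations
    //         // that refuse to run inside a runtime (e.g. a fresh `block_on`)
    //         // are permitted for the duration of the closure.
    //         run_blocking_work() // `run_blocking_work` is hypothetical
    //     });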
}