• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! This module has containers for storing the tasks spawned on a scheduler. The
2 //! `OwnedTasks` container is thread-safe but can only store tasks that
3 //! implement Send. The `LocalOwnedTasks` container is not thread safe, but can
4 //! store non-Send tasks.
5 //!
6 //! The collections can be closed to prevent adding new tasks during shutdown of
7 //! the scheduler with the collection.
8 
9 use crate::future::Future;
10 use crate::loom::cell::UnsafeCell;
11 use crate::loom::sync::Mutex;
12 use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
13 use crate::util::linked_list::{CountedLinkedList, Link, LinkedList};
14 
15 use std::marker::PhantomData;
16 use std::num::NonZeroU64;
17 
// The id produced by the functions below is used to verify whether a given task
// is stored in this OwnedTasks, or in some other collection. The counter starts
// at one so we can use `None` for tasks not owned by any list.
21 //
22 // The safety checks in this file can technically be violated if the counter is
23 // overflown, but the checks are not supposed to ever fail unless there is a
24 // bug in Tokio, so we accept that certain bugs would not be caught if the two
25 // mixed up runtimes happen to have the same id.
26 
27 cfg_has_atomic_u64! {
28     use std::sync::atomic::{AtomicU64, Ordering};
29 
30     static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);
31 
32     fn get_next_id() -> NonZeroU64 {
33         loop {
34             let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
35             if let Some(id) = NonZeroU64::new(id) {
36                 return id;
37             }
38         }
39     }
40 }
41 
42 cfg_not_has_atomic_u64! {
43     use std::sync::atomic::{AtomicU32, Ordering};
44 
45     static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);
46 
47     fn get_next_id() -> NonZeroU64 {
48         loop {
49             let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
50             if let Some(id) = NonZeroU64::new(u64::from(id)) {
51                 return id;
52             }
53         }
54     }
55 }
56 
57 pub(crate) struct OwnedTasks<S: 'static> {
58     inner: Mutex<CountedOwnedTasksInner<S>>,
59     pub(crate) id: NonZeroU64,
60 }
61 struct CountedOwnedTasksInner<S: 'static> {
62     list: CountedLinkedList<Task<S>, <Task<S> as Link>::Target>,
63     closed: bool,
64 }
65 pub(crate) struct LocalOwnedTasks<S: 'static> {
66     inner: UnsafeCell<OwnedTasksInner<S>>,
67     pub(crate) id: NonZeroU64,
68     _not_send_or_sync: PhantomData<*const ()>,
69 }
70 struct OwnedTasksInner<S: 'static> {
71     list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
72     closed: bool,
73 }
74 
75 impl<S: 'static> OwnedTasks<S> {
new() -> Self76     pub(crate) fn new() -> Self {
77         Self {
78             inner: Mutex::new(CountedOwnedTasksInner {
79                 list: CountedLinkedList::new(),
80                 closed: false,
81             }),
82             id: get_next_id(),
83         }
84     }
85 
86     /// Binds the provided task to this OwnedTasks instance. This fails if the
87     /// OwnedTasks has been closed.
bind<T>( &self, task: T, scheduler: S, id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static,88     pub(crate) fn bind<T>(
89         &self,
90         task: T,
91         scheduler: S,
92         id: super::Id,
93     ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
94     where
95         S: Schedule,
96         T: Future + Send + 'static,
97         T::Output: Send + 'static,
98     {
99         let (task, notified, join) = super::new_task(task, scheduler, id);
100         let notified = unsafe { self.bind_inner(task, notified) };
101         (join, notified)
102     }
103 
104     /// The part of `bind` that's the same for every type of future.
bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>> where S: Schedule,105     unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
106     where
107         S: Schedule,
108     {
109         unsafe {
110             // safety: We just created the task, so we have exclusive access
111             // to the field.
112             task.header().set_owner_id(self.id);
113         }
114 
115         let mut lock = self.inner.lock();
116         if lock.closed {
117             drop(lock);
118             drop(notified);
119             task.shutdown();
120             None
121         } else {
122             lock.list.push_front(task);
123             Some(notified)
124         }
125     }
126 
127     /// Asserts that the given task is owned by this OwnedTasks and convert it to
128     /// a LocalNotified, giving the thread permission to poll this task.
129     #[inline]
assert_owner(&self, task: Notified<S>) -> LocalNotified<S>130     pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
131         debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
132 
133         // safety: All tasks bound to this OwnedTasks are Send, so it is safe
134         // to poll it on this thread no matter what thread we are on.
135         LocalNotified {
136             task: task.0,
137             _not_send: PhantomData,
138         }
139     }
140 
141     /// Shuts down all tasks in the collection. This call also closes the
142     /// collection, preventing new items from being added.
close_and_shutdown_all(&self) where S: Schedule,143     pub(crate) fn close_and_shutdown_all(&self)
144     where
145         S: Schedule,
146     {
147         // The first iteration of the loop was unrolled so it can set the
148         // closed bool.
149         let first_task = {
150             let mut lock = self.inner.lock();
151             lock.closed = true;
152             lock.list.pop_back()
153         };
154         match first_task {
155             Some(task) => task.shutdown(),
156             None => return,
157         }
158 
159         loop {
160             let task = match self.inner.lock().list.pop_back() {
161                 Some(task) => task,
162                 None => return,
163             };
164 
165             task.shutdown();
166         }
167     }
168 
active_tasks_count(&self) -> usize169     pub(crate) fn active_tasks_count(&self) -> usize {
170         self.inner.lock().list.count()
171     }
172 
remove(&self, task: &Task<S>) -> Option<Task<S>>173     pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
174         // If the task's owner ID is `None` then it is not part of any list and
175         // doesn't need removing.
176         let task_id = task.header().get_owner_id()?;
177 
178         assert_eq!(task_id, self.id);
179 
180         // safety: We just checked that the provided task is not in some other
181         // linked list.
182         unsafe { self.inner.lock().list.remove(task.header_ptr()) }
183     }
184 
is_empty(&self) -> bool185     pub(crate) fn is_empty(&self) -> bool {
186         self.inner.lock().list.is_empty()
187     }
188 }
189 
190 cfg_taskdump! {
191     impl<S: 'static> OwnedTasks<S> {
192         /// Locks the tasks, and calls `f` on an iterator over them.
193         pub(crate) fn for_each<F>(&self, f: F)
194         where
195             F: FnMut(&Task<S>)
196         {
197             self.inner.lock().list.for_each(f)
198         }
199     }
200 }
201 
202 impl<S: 'static> LocalOwnedTasks<S> {
new() -> Self203     pub(crate) fn new() -> Self {
204         Self {
205             inner: UnsafeCell::new(OwnedTasksInner {
206                 list: LinkedList::new(),
207                 closed: false,
208             }),
209             id: get_next_id(),
210             _not_send_or_sync: PhantomData,
211         }
212     }
213 
bind<T>( &self, task: T, scheduler: S, id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + 'static, T::Output: 'static,214     pub(crate) fn bind<T>(
215         &self,
216         task: T,
217         scheduler: S,
218         id: super::Id,
219     ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
220     where
221         S: Schedule,
222         T: Future + 'static,
223         T::Output: 'static,
224     {
225         let (task, notified, join) = super::new_task(task, scheduler, id);
226 
227         unsafe {
228             // safety: We just created the task, so we have exclusive access
229             // to the field.
230             task.header().set_owner_id(self.id);
231         }
232 
233         if self.is_closed() {
234             drop(notified);
235             task.shutdown();
236             (join, None)
237         } else {
238             self.with_inner(|inner| {
239                 inner.list.push_front(task);
240             });
241             (join, Some(notified))
242         }
243     }
244 
245     /// Shuts down all tasks in the collection. This call also closes the
246     /// collection, preventing new items from being added.
close_and_shutdown_all(&self) where S: Schedule,247     pub(crate) fn close_and_shutdown_all(&self)
248     where
249         S: Schedule,
250     {
251         self.with_inner(|inner| inner.closed = true);
252 
253         while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
254             task.shutdown();
255         }
256     }
257 
remove(&self, task: &Task<S>) -> Option<Task<S>>258     pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
259         // If the task's owner ID is `None` then it is not part of any list and
260         // doesn't need removing.
261         let task_id = task.header().get_owner_id()?;
262 
263         assert_eq!(task_id, self.id);
264 
265         self.with_inner(|inner|
266             // safety: We just checked that the provided task is not in some
267             // other linked list.
268             unsafe { inner.list.remove(task.header_ptr()) })
269     }
270 
271     /// Asserts that the given task is owned by this LocalOwnedTasks and convert
272     /// it to a LocalNotified, giving the thread permission to poll this task.
273     #[inline]
assert_owner(&self, task: Notified<S>) -> LocalNotified<S>274     pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
275         assert_eq!(task.header().get_owner_id(), Some(self.id));
276 
277         // safety: The task was bound to this LocalOwnedTasks, and the
278         // LocalOwnedTasks is not Send or Sync, so we are on the right thread
279         // for polling this task.
280         LocalNotified {
281             task: task.0,
282             _not_send: PhantomData,
283         }
284     }
285 
286     #[inline]
with_inner<F, T>(&self, f: F) -> T where F: FnOnce(&mut OwnedTasksInner<S>) -> T,287     fn with_inner<F, T>(&self, f: F) -> T
288     where
289         F: FnOnce(&mut OwnedTasksInner<S>) -> T,
290     {
291         // safety: This type is not Sync, so concurrent calls of this method
292         // can't happen.  Furthermore, all uses of this method in this file make
293         // sure that they don't call `with_inner` recursively.
294         self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
295     }
296 
is_closed(&self) -> bool297     pub(crate) fn is_closed(&self) -> bool {
298         self.with_inner(|inner| inner.closed)
299     }
300 
is_empty(&self) -> bool301     pub(crate) fn is_empty(&self) -> bool {
302         self.with_inner(|inner| inner.list.is_empty())
303     }
304 }
305 
#[cfg(test)]
mod tests {
    use super::*;

    // Other tests may be drawing ids at the same time, so consecutive values
    // cannot be expected here; only check that each id is larger than the
    // one obtained before it.
    #[test]
    fn test_id_not_broken() {
        let _ = (0..1000).fold(get_next_id(), |last_id, _| {
            let next_id = get_next_id();
            assert!(last_id < next_id);
            next_id
        });
    }
}
323