1 //! This module has containers for storing the tasks spawned on a scheduler. The 2 //! `OwnedTasks` container is thread-safe but can only store tasks that 3 //! implement Send. The `LocalOwnedTasks` container is not thread safe, but can 4 //! store non-Send tasks. 5 //! 6 //! The collections can be closed to prevent adding new tasks during shutdown of 7 //! the scheduler with the collection. 8 9 use crate::future::Future; 10 use crate::loom::cell::UnsafeCell; 11 use crate::loom::sync::Mutex; 12 use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task}; 13 use crate::util::linked_list::{CountedLinkedList, Link, LinkedList}; 14 15 use std::marker::PhantomData; 16 use std::num::NonZeroU64; 17 18 // The id from the module below is used to verify whether a given task is stored 19 // in this OwnedTasks, or some other task. The counter starts at one so we can 20 // use `None` for tasks not owned by any list. 21 // 22 // The safety checks in this file can technically be violated if the counter is 23 // overflown, but the checks are not supposed to ever fail unless there is a 24 // bug in Tokio, so we accept that certain bugs would not be caught if the two 25 // mixed up runtimes happen to have the same id. 26 27 cfg_has_atomic_u64! { 28 use std::sync::atomic::{AtomicU64, Ordering}; 29 30 static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1); 31 32 fn get_next_id() -> NonZeroU64 { 33 loop { 34 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); 35 if let Some(id) = NonZeroU64::new(id) { 36 return id; 37 } 38 } 39 } 40 } 41 42 cfg_not_has_atomic_u64! { 43 use std::sync::atomic::{AtomicU32, Ordering}; 44 45 static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1); 46 47 fn get_next_id() -> NonZeroU64 { 48 loop { 49 let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed); 50 if let Some(id) = NonZeroU64::new(u64::from(id)) { 51 return id; 52 } 53 } 54 } 55 } 56 57 pub(crate) struct OwnedTasks<S: 'static> { 58 inner: Mutex<CountedOwnedTasksInner<S>>, 59 pub(crate) id: NonZeroU64, 60 } 61 struct CountedOwnedTasksInner<S: 'static> { 62 list: CountedLinkedList<Task<S>, <Task<S> as Link>::Target>, 63 closed: bool, 64 } 65 pub(crate) struct LocalOwnedTasks<S: 'static> { 66 inner: UnsafeCell<OwnedTasksInner<S>>, 67 pub(crate) id: NonZeroU64, 68 _not_send_or_sync: PhantomData<*const ()>, 69 } 70 struct OwnedTasksInner<S: 'static> { 71 list: LinkedList<Task<S>, <Task<S> as Link>::Target>, 72 closed: bool, 73 } 74 75 impl<S: 'static> OwnedTasks<S> { new() -> Self76 pub(crate) fn new() -> Self { 77 Self { 78 inner: Mutex::new(CountedOwnedTasksInner { 79 list: CountedLinkedList::new(), 80 closed: false, 81 }), 82 id: get_next_id(), 83 } 84 } 85 86 /// Binds the provided task to this OwnedTasks instance. This fails if the 87 /// OwnedTasks has been closed. bind<T>( &self, task: T, scheduler: S, id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static,88 pub(crate) fn bind<T>( 89 &self, 90 task: T, 91 scheduler: S, 92 id: super::Id, 93 ) -> (JoinHandle<T::Output>, Option<Notified<S>>) 94 where 95 S: Schedule, 96 T: Future + Send + 'static, 97 T::Output: Send + 'static, 98 { 99 let (task, notified, join) = super::new_task(task, scheduler, id); 100 let notified = unsafe { self.bind_inner(task, notified) }; 101 (join, notified) 102 } 103 104 /// The part of `bind` that's the same for every type of future. bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>> where S: Schedule,105 unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>> 106 where 107 S: Schedule, 108 { 109 unsafe { 110 // safety: We just created the task, so we have exclusive access 111 // to the field. 112 task.header().set_owner_id(self.id); 113 } 114 115 let mut lock = self.inner.lock(); 116 if lock.closed { 117 drop(lock); 118 drop(notified); 119 task.shutdown(); 120 None 121 } else { 122 lock.list.push_front(task); 123 Some(notified) 124 } 125 } 126 127 /// Asserts that the given task is owned by this OwnedTasks and convert it to 128 /// a LocalNotified, giving the thread permission to poll this task. 129 #[inline] assert_owner(&self, task: Notified<S>) -> LocalNotified<S>130 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { 131 debug_assert_eq!(task.header().get_owner_id(), Some(self.id)); 132 133 // safety: All tasks bound to this OwnedTasks are Send, so it is safe 134 // to poll it on this thread no matter what thread we are on. 135 LocalNotified { 136 task: task.0, 137 _not_send: PhantomData, 138 } 139 } 140 141 /// Shuts down all tasks in the collection. This call also closes the 142 /// collection, preventing new items from being added. close_and_shutdown_all(&self) where S: Schedule,143 pub(crate) fn close_and_shutdown_all(&self) 144 where 145 S: Schedule, 146 { 147 // The first iteration of the loop was unrolled so it can set the 148 // closed bool. 149 let first_task = { 150 let mut lock = self.inner.lock(); 151 lock.closed = true; 152 lock.list.pop_back() 153 }; 154 match first_task { 155 Some(task) => task.shutdown(), 156 None => return, 157 } 158 159 loop { 160 let task = match self.inner.lock().list.pop_back() { 161 Some(task) => task, 162 None => return, 163 }; 164 165 task.shutdown(); 166 } 167 } 168 active_tasks_count(&self) -> usize169 pub(crate) fn active_tasks_count(&self) -> usize { 170 self.inner.lock().list.count() 171 } 172 remove(&self, task: &Task<S>) -> Option<Task<S>>173 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { 174 // If the task's owner ID is `None` then it is not part of any list and 175 // doesn't need removing. 176 let task_id = task.header().get_owner_id()?; 177 178 assert_eq!(task_id, self.id); 179 180 // safety: We just checked that the provided task is not in some other 181 // linked list. 182 unsafe { self.inner.lock().list.remove(task.header_ptr()) } 183 } 184 is_empty(&self) -> bool185 pub(crate) fn is_empty(&self) -> bool { 186 self.inner.lock().list.is_empty() 187 } 188 } 189 190 cfg_taskdump! { 191 impl<S: 'static> OwnedTasks<S> { 192 /// Locks the tasks, and calls `f` on an iterator over them. 193 pub(crate) fn for_each<F>(&self, f: F) 194 where 195 F: FnMut(&Task<S>) 196 { 197 self.inner.lock().list.for_each(f) 198 } 199 } 200 } 201 202 impl<S: 'static> LocalOwnedTasks<S> { new() -> Self203 pub(crate) fn new() -> Self { 204 Self { 205 inner: UnsafeCell::new(OwnedTasksInner { 206 list: LinkedList::new(), 207 closed: false, 208 }), 209 id: get_next_id(), 210 _not_send_or_sync: PhantomData, 211 } 212 } 213 bind<T>( &self, task: T, scheduler: S, id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + 'static, T::Output: 'static,214 pub(crate) fn bind<T>( 215 &self, 216 task: T, 217 scheduler: S, 218 id: super::Id, 219 ) -> (JoinHandle<T::Output>, Option<Notified<S>>) 220 where 221 S: Schedule, 222 T: Future + 'static, 223 T::Output: 'static, 224 { 225 let (task, notified, join) = super::new_task(task, scheduler, id); 226 227 unsafe { 228 // safety: We just created the task, so we have exclusive access 229 // to the field. 230 task.header().set_owner_id(self.id); 231 } 232 233 if self.is_closed() { 234 drop(notified); 235 task.shutdown(); 236 (join, None) 237 } else { 238 self.with_inner(|inner| { 239 inner.list.push_front(task); 240 }); 241 (join, Some(notified)) 242 } 243 } 244 245 /// Shuts down all tasks in the collection. This call also closes the 246 /// collection, preventing new items from being added. close_and_shutdown_all(&self) where S: Schedule,247 pub(crate) fn close_and_shutdown_all(&self) 248 where 249 S: Schedule, 250 { 251 self.with_inner(|inner| inner.closed = true); 252 253 while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) { 254 task.shutdown(); 255 } 256 } 257 remove(&self, task: &Task<S>) -> Option<Task<S>>258 pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> { 259 // If the task's owner ID is `None` then it is not part of any list and 260 // doesn't need removing. 261 let task_id = task.header().get_owner_id()?; 262 263 assert_eq!(task_id, self.id); 264 265 self.with_inner(|inner| 266 // safety: We just checked that the provided task is not in some 267 // other linked list. 268 unsafe { inner.list.remove(task.header_ptr()) }) 269 } 270 271 /// Asserts that the given task is owned by this LocalOwnedTasks and convert 272 /// it to a LocalNotified, giving the thread permission to poll this task. 273 #[inline] assert_owner(&self, task: Notified<S>) -> LocalNotified<S>274 pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> { 275 assert_eq!(task.header().get_owner_id(), Some(self.id)); 276 277 // safety: The task was bound to this LocalOwnedTasks, and the 278 // LocalOwnedTasks is not Send or Sync, so we are on the right thread 279 // for polling this task. 280 LocalNotified { 281 task: task.0, 282 _not_send: PhantomData, 283 } 284 } 285 286 #[inline] with_inner<F, T>(&self, f: F) -> T where F: FnOnce(&mut OwnedTasksInner<S>) -> T,287 fn with_inner<F, T>(&self, f: F) -> T 288 where 289 F: FnOnce(&mut OwnedTasksInner<S>) -> T, 290 { 291 // safety: This type is not Sync, so concurrent calls of this method 292 // can't happen. Furthermore, all uses of this method in this file make 293 // sure that they don't call `with_inner` recursively. 294 self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) }) 295 } 296 is_closed(&self) -> bool297 pub(crate) fn is_closed(&self) -> bool { 298 self.with_inner(|inner| inner.closed) 299 } 300 is_empty(&self) -> bool301 pub(crate) fn is_empty(&self) -> bool { 302 self.with_inner(|inner| inner.list.is_empty()) 303 } 304 } 305 306 #[cfg(test)] 307 mod tests { 308 use super::*; 309 310 // This test may run in parallel with other tests, so we only test that ids 311 // come in increasing order. 312 #[test] test_id_not_broken()313 fn test_id_not_broken() { 314 let mut last_id = get_next_id(); 315 316 for _ in 0..1000 { 317 let next_id = get_next_id(); 318 assert!(last_id < next_id); 319 last_id = next_id; 320 } 321 } 322 } 323