//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::scheduler::multi_thread::{Overflow, Stats};
use crate::runtime::task;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The `UnsignedShort` in the LSB
    /// half is the "real" head of the queue. The `UnsignedShort` in the MSB
    /// half is set by a stealer in the process of stealing values. It
    /// represents the first value being stolen in the batch. The
    /// `UnsignedShort` indices are intentionally wider than strictly required
    /// for buffer indexing in order to provide ABA mitigation and make it
    /// possible to distinguish between full and empty buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicUnsignedLong,

    /// Only updated by producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;
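
// A minimal illustrative sketch of the mask arithmetic: because
// `LOCAL_QUEUE_CAPACITY` is a power of two, `pos & MASK` is equivalent to
// `pos % LOCAL_QUEUE_CAPACITY`, which is why slot indices throughout this
// file are computed with `& MASK`. The test name is arbitrary.
#[test]
fn test_mask_is_modulo_capacity() {
    for pos in 0..(LOCAL_QUEUE_CAPACITY * 3) {
        assert_eq!(pos & MASK, pos % LOCAL_QUEUE_CAPACITY);
    }
}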

// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}

/// Create a new local run-queue
pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicUnsignedLong::new(0),
        tail: AtomicUnsignedShort::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}
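
// A minimal illustrative sketch of the handle pair: `local()` returns a
// producer (`Local`) and a consumer (`Steal`) backed by the same `Inner`
// allocation. The assertions only exercise the empty-queue invariants; the
// test name and the `()` scheduler type parameter are assumptions chosen for
// demonstration.
#[cfg(not(loom))]
#[test]
fn test_new_queue_handles_start_empty() {
    let (steal, mut local) = local::<()>();

    assert!(steal.is_empty());
    assert!(!local.has_tasks());
    assert_eq!(local.len(), 0);
    assert_eq!(local.remaining_slots(), local.max_capacity());
    assert!(local.pop().is_none());
}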

impl<T> Local<T> {
    /// Returns the number of entries in the queue
    pub(crate) fn len(&self) -> usize {
        self.inner.len() as usize
    }

    /// How many tasks can be pushed into the queue
    pub(crate) fn remaining_slots(&self) -> usize {
        self.inner.remaining_slots()
    }

    pub(crate) fn max_capacity(&self) -> usize {
        LOCAL_QUEUE_CAPACITY
    }

    /// Returns true if there are any entries in the queue.
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this.
    pub(crate) fn has_tasks(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Pushes a batch of tasks to the back of the queue. All tasks must fit in
    /// the local queue.
    ///
    /// # Panics
    ///
    /// The method panics if there is not enough capacity to fit the tasks in
    /// the queue.
    pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
        let len = tasks.len();
        assert!(len <= LOCAL_QUEUE_CAPACITY);

        if len == 0 {
            // Nothing to do
            return;
        }

        let head = self.inner.head.load(Acquire);
        let (steal, _) = unpack(head);

        // safety: this is the **only** thread that updates this cell.
        let mut tail = unsafe { self.inner.tail.unsync_load() };

        if tail.wrapping_sub(steal) <= (LOCAL_QUEUE_CAPACITY - len) as UnsignedShort {
            // Yes, this if condition is structured a bit weird (first block
            // does nothing, second panics). It is this way to match
            // `push_back_or_overflow`.
        } else {
            panic!()
        }

        for task in tasks {
            let idx = tail as usize & MASK;

            self.inner.buffer[idx].with_mut(|ptr| {
                // Write the task to the slot
                //
                // Safety: There is only one producer and the above `if`
                // condition ensures we don't touch a cell if there is a
                // value, thus no consumer.
                unsafe {
                    ptr::write((*ptr).as_mut_ptr(), task);
                }
            });

            tail = tail.wrapping_add(1);
        }

        self.inner.tail.store(tail, Release);
    }

    /// Pushes a task to the back of the local queue. If there is not enough
    /// capacity in the queue, this triggers the overflow operation.
    ///
    /// When the queue overflows, half of the current contents of the queue are
    /// moved to the given injection queue. This frees up capacity for more
    /// tasks to be pushed into the local queue.
    pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
        &mut self,
        mut task: task::Notified<T>,
        overflow: &O,
        stats: &mut Stats,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                overflow.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, overflow, stats) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        self.push_back_finish(task, tail);
    }

    // Second half of `push_back_or_overflow`.
    fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the capacity check in
            // the caller ensures we don't touch a cell if there is a value,
            // thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow<O: Overflow<T>>(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        overflow: &O,
        stats: &mut Stats,
    ) -> Result<(), task::Notified<T>> {
        /// How many elements we take from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue, as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {tail}; head = {head}"
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(NUM_TASKS_TAKEN),
                    head.wrapping_add(NUM_TASKS_TAKEN),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: UnsignedLong,
            i: UnsignedLong,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            head: head as UnsignedLong,
            i: 0,
        };
        overflow.push_batch(batch_iter.chain(std::iter::once(task)));

        // Add 1 to factor in the task currently being scheduled.
        stats.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    pub(crate) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals half the tasks from `self` and places them into `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_stats: &mut Stats,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        dst_stats.incr_steal_count(n as u16);
        dst_stats.incr_steal_operations();

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(
            n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
            "actual = {n}"
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real` signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

cfg_unstable_metrics! {
    impl<T> Steal<T> {
        pub(crate) fn len(&self) -> usize {
            self.0.len() as _
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

impl<T> Inner<T> {
    fn remaining_slots(&self) -> usize {
        let (steal, _) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        LOCAL_QUEUE_CAPACITY - (tail.wrapping_sub(steal) as usize)
    }

    fn len(&self) -> UnsignedShort {
        let (_, head) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        tail.wrapping_sub(head)
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
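
// A minimal illustrative sketch of the wrapping arithmetic that `len`,
// `remaining_slots`, and the capacity checks rely on:
// `tail.wrapping_sub(head)` stays correct even after the indices wrap around
// `UnsignedShort::MAX`. The sample values are arbitrary.
#[test]
fn test_wrapping_len_sketch() {
    let head: UnsignedShort = UnsignedShort::MAX - 1;
    let tail: UnsignedShort = head.wrapping_add(3);

    assert_eq!(tail.wrapping_sub(head), 3);
}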

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
    let real = n & UnsignedShort::MAX as UnsignedLong;
    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

    (steal as UnsignedShort, real as UnsignedShort)
}

/// Join the two head values
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}
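
// A minimal illustrative sketch of the head encoding: `unpack(pack(steal, real))`
// recovers the original pair, and the two halves are equal exactly when no
// stealer is active. The sample values are arbitrary.
#[test]
fn test_pack_unpack_roundtrip() {
    let steal: UnsignedShort = 5;
    let real: UnsignedShort = 7;

    assert_eq!(unpack(pack(steal, real)), (steal, real));

    // Equal halves encode "no in-progress steal".
    let (s, r) = unpack(pack(3, 3));
    assert_eq!(s, r);
}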

#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}