• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //! Run-queue structures to support a work-stealing scheduler
2 
3 use crate::loom::cell::UnsafeCell;
4 use crate::loom::sync::Arc;
5 use crate::runtime::scheduler::multi_thread::{Overflow, Stats};
6 use crate::runtime::task;
7 
8 use std::mem::{self, MaybeUninit};
9 use std::ptr;
10 use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};
11 
12 // Use wider integers when possible to increase ABA resilience.
13 //
14 // See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
15 cfg_has_atomic_u64! {
16     type UnsignedShort = u32;
17     type UnsignedLong = u64;
18     type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
19     type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
20 }
21 cfg_not_has_atomic_u64! {
22     type UnsignedShort = u16;
23     type UnsignedLong = u32;
24     type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
25     type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
26 }
27 
28 /// Producer handle. May only be used from a single thread.
29 pub(crate) struct Local<T: 'static> {
30     inner: Arc<Inner<T>>,
31 }
32 
33 /// Consumer handle. May be used from many threads.
34 pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);
35 
36 pub(crate) struct Inner<T: 'static> {
37     /// Concurrently updated by many threads.
38     ///
39     /// Contains two `UnsignedShort` values. The LSB byte is the "real" head of
40     /// the queue. The `UnsignedShort` in the MSB is set by a stealer in process
41     /// of stealing values. It represents the first value being stolen in the
42     /// batch. The `UnsignedShort` indices are intentionally wider than strictly
43     /// required for buffer indexing in order to provide ABA mitigation and make
44     /// it possible to distinguish between full and empty buffers.
45     ///
46     /// When both `UnsignedShort` values are the same, there is no active
47     /// stealer.
48     ///
49     /// Tracking an in-progress stealer prevents a wrapping scenario.
50     head: AtomicUnsignedLong,
51 
52     /// Only updated by producer thread but read by many threads.
53     tail: AtomicUnsignedShort,
54 
55     /// Elements
56     buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
57 }
58 
59 unsafe impl<T> Send for Inner<T> {}
60 unsafe impl<T> Sync for Inner<T> {}
61 
/// Number of task slots in each local queue.
#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

/// Bit mask that maps a queue position to a slot index (`position & MASK`).
///
/// This only works as a modulo because `LOCAL_QUEUE_CAPACITY` is a power of
/// two; keep it that way when changing the capacity.
const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;
72 
73 // Constructing the fixed size array directly is very awkward. The only way to
74 // do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
75 // the contents are not Copy. The trick with defining a const doesn't work for
76 // generic types.
make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]>77 fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
78     assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);
79 
80     // safety: We check that the length is correct.
81     unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
82 }
83 
84 /// Create a new local run-queue
local<T: 'static>() -> (Steal<T>, Local<T>)85 pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
86     let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);
87 
88     for _ in 0..LOCAL_QUEUE_CAPACITY {
89         buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
90     }
91 
92     let inner = Arc::new(Inner {
93         head: AtomicUnsignedLong::new(0),
94         tail: AtomicUnsignedShort::new(0),
95         buffer: make_fixed_size(buffer.into_boxed_slice()),
96     });
97 
98     let local = Local {
99         inner: inner.clone(),
100     };
101 
102     let remote = Steal(inner);
103 
104     (remote, local)
105 }
106 
107 impl<T> Local<T> {
108     /// Returns the number of entries in the queue
len(&self) -> usize109     pub(crate) fn len(&self) -> usize {
110         self.inner.len() as usize
111     }
112 
113     /// How many tasks can be pushed into the queue
remaining_slots(&self) -> usize114     pub(crate) fn remaining_slots(&self) -> usize {
115         self.inner.remaining_slots()
116     }
117 
max_capacity(&self) -> usize118     pub(crate) fn max_capacity(&self) -> usize {
119         LOCAL_QUEUE_CAPACITY
120     }
121 
122     /// Returns false if there are any entries in the queue
123     ///
124     /// Separate to is_stealable so that refactors of is_stealable to "protect"
125     /// some tasks from stealing won't affect this
has_tasks(&self) -> bool126     pub(crate) fn has_tasks(&self) -> bool {
127         !self.inner.is_empty()
128     }
129 
130     /// Pushes a batch of tasks to the back of the queue. All tasks must fit in
131     /// the local queue.
132     ///
133     /// # Panics
134     ///
135     /// The method panics if there is not enough capacity to fit in the queue.
push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>)136     pub(crate) fn push_back(&mut self, tasks: impl ExactSizeIterator<Item = task::Notified<T>>) {
137         let len = tasks.len();
138         assert!(len <= LOCAL_QUEUE_CAPACITY);
139 
140         if len == 0 {
141             // Nothing to do
142             return;
143         }
144 
145         let head = self.inner.head.load(Acquire);
146         let (steal, _) = unpack(head);
147 
148         // safety: this is the **only** thread that updates this cell.
149         let mut tail = unsafe { self.inner.tail.unsync_load() };
150 
151         if tail.wrapping_sub(steal) <= (LOCAL_QUEUE_CAPACITY - len) as UnsignedShort {
152             // Yes, this if condition is structured a bit weird (first block
153             // does nothing, second returns an error). It is this way to match
154             // `push_back_or_overflow`.
155         } else {
156             panic!()
157         }
158 
159         for task in tasks {
160             let idx = tail as usize & MASK;
161 
162             self.inner.buffer[idx].with_mut(|ptr| {
163                 // Write the task to the slot
164                 //
165                 // Safety: There is only one producer and the above `if`
166                 // condition ensures we don't touch a cell if there is a
167                 // value, thus no consumer.
168                 unsafe {
169                     ptr::write((*ptr).as_mut_ptr(), task);
170                 }
171             });
172 
173             tail = tail.wrapping_add(1);
174         }
175 
176         self.inner.tail.store(tail, Release);
177     }
178 
179     /// Pushes a task to the back of the local queue, if there is not enough
180     /// capacity in the queue, this triggers the overflow operation.
181     ///
182     /// When the queue overflows, half of the curent contents of the queue is
183     /// moved to the given Injection queue. This frees up capacity for more
184     /// tasks to be pushed into the local queue.
push_back_or_overflow<O: Overflow<T>>( &mut self, mut task: task::Notified<T>, overflow: &O, stats: &mut Stats, )185     pub(crate) fn push_back_or_overflow<O: Overflow<T>>(
186         &mut self,
187         mut task: task::Notified<T>,
188         overflow: &O,
189         stats: &mut Stats,
190     ) {
191         let tail = loop {
192             let head = self.inner.head.load(Acquire);
193             let (steal, real) = unpack(head);
194 
195             // safety: this is the **only** thread that updates this cell.
196             let tail = unsafe { self.inner.tail.unsync_load() };
197 
198             if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
199                 // There is capacity for the task
200                 break tail;
201             } else if steal != real {
202                 // Concurrently stealing, this will free up capacity, so only
203                 // push the task onto the inject queue
204                 overflow.push(task);
205                 return;
206             } else {
207                 // Push the current task and half of the queue into the
208                 // inject queue.
209                 match self.push_overflow(task, real, tail, overflow, stats) {
210                     Ok(_) => return,
211                     // Lost the race, try again
212                     Err(v) => {
213                         task = v;
214                     }
215                 }
216             }
217         };
218 
219         self.push_back_finish(task, tail);
220     }
221 
222     // Second half of `push_back`
push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort)223     fn push_back_finish(&self, task: task::Notified<T>, tail: UnsignedShort) {
224         // Map the position to a slot index.
225         let idx = tail as usize & MASK;
226 
227         self.inner.buffer[idx].with_mut(|ptr| {
228             // Write the task to the slot
229             //
230             // Safety: There is only one producer and the above `if`
231             // condition ensures we don't touch a cell if there is a
232             // value, thus no consumer.
233             unsafe {
234                 ptr::write((*ptr).as_mut_ptr(), task);
235             }
236         });
237 
238         // Make the task available. Synchronizes with a load in
239         // `steal_into2`.
240         self.inner.tail.store(tail.wrapping_add(1), Release);
241     }
242 
243     /// Moves a batch of tasks into the inject queue.
244     ///
245     /// This will temporarily make some of the tasks unavailable to stealers.
246     /// Once `push_overflow` is done, a notification is sent out, so if other
247     /// workers "missed" some of the tasks during a steal, they will get
248     /// another opportunity.
249     #[inline(never)]
push_overflow<O: Overflow<T>>( &mut self, task: task::Notified<T>, head: UnsignedShort, tail: UnsignedShort, overflow: &O, stats: &mut Stats, ) -> Result<(), task::Notified<T>>250     fn push_overflow<O: Overflow<T>>(
251         &mut self,
252         task: task::Notified<T>,
253         head: UnsignedShort,
254         tail: UnsignedShort,
255         overflow: &O,
256         stats: &mut Stats,
257     ) -> Result<(), task::Notified<T>> {
258         /// How many elements are we taking from the local queue.
259         ///
260         /// This is one less than the number of tasks pushed to the inject
261         /// queue as we are also inserting the `task` argument.
262         const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;
263 
264         assert_eq!(
265             tail.wrapping_sub(head) as usize,
266             LOCAL_QUEUE_CAPACITY,
267             "queue is not full; tail = {}; head = {}",
268             tail,
269             head
270         );
271 
272         let prev = pack(head, head);
273 
274         // Claim a bunch of tasks
275         //
276         // We are claiming the tasks **before** reading them out of the buffer.
277         // This is safe because only the **current** thread is able to push new
278         // tasks.
279         //
280         // There isn't really any need for memory ordering... Relaxed would
281         // work. This is because all tasks are pushed into the queue from the
282         // current thread (or memory has been acquired if the local queue handle
283         // moved).
284         if self
285             .inner
286             .head
287             .compare_exchange(
288                 prev,
289                 pack(
290                     head.wrapping_add(NUM_TASKS_TAKEN),
291                     head.wrapping_add(NUM_TASKS_TAKEN),
292                 ),
293                 Release,
294                 Relaxed,
295             )
296             .is_err()
297         {
298             // We failed to claim the tasks, losing the race. Return out of
299             // this function and try the full `push` routine again. The queue
300             // may not be full anymore.
301             return Err(task);
302         }
303 
304         /// An iterator that takes elements out of the run queue.
305         struct BatchTaskIter<'a, T: 'static> {
306             buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
307             head: UnsignedLong,
308             i: UnsignedLong,
309         }
310         impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
311             type Item = task::Notified<T>;
312 
313             #[inline]
314             fn next(&mut self) -> Option<task::Notified<T>> {
315                 if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
316                     None
317                 } else {
318                     let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
319                     let slot = &self.buffer[i_idx];
320 
321                     // safety: Our CAS from before has assumed exclusive ownership
322                     // of the task pointers in this range.
323                     let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });
324 
325                     self.i += 1;
326                     Some(task)
327                 }
328             }
329         }
330 
331         // safety: The CAS above ensures that no consumer will look at these
332         // values again, and we are the only producer.
333         let batch_iter = BatchTaskIter {
334             buffer: &self.inner.buffer,
335             head: head as UnsignedLong,
336             i: 0,
337         };
338         overflow.push_batch(batch_iter.chain(std::iter::once(task)));
339 
340         // Add 1 to factor in the task currently being scheduled.
341         stats.incr_overflow_count();
342 
343         Ok(())
344     }
345 
346     /// Pops a task from the local queue.
pop(&mut self) -> Option<task::Notified<T>>347     pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
348         let mut head = self.inner.head.load(Acquire);
349 
350         let idx = loop {
351             let (steal, real) = unpack(head);
352 
353             // safety: this is the **only** thread that updates this cell.
354             let tail = unsafe { self.inner.tail.unsync_load() };
355 
356             if real == tail {
357                 // queue is empty
358                 return None;
359             }
360 
361             let next_real = real.wrapping_add(1);
362 
363             // If `steal == real` there are no concurrent stealers. Both `steal`
364             // and `real` are updated.
365             let next = if steal == real {
366                 pack(next_real, next_real)
367             } else {
368                 assert_ne!(steal, next_real);
369                 pack(steal, next_real)
370             };
371 
372             // Attempt to claim a task.
373             let res = self
374                 .inner
375                 .head
376                 .compare_exchange(head, next, AcqRel, Acquire);
377 
378             match res {
379                 Ok(_) => break real as usize & MASK,
380                 Err(actual) => head = actual,
381             }
382         };
383 
384         Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
385     }
386 }
387 
388 impl<T> Steal<T> {
is_empty(&self) -> bool389     pub(crate) fn is_empty(&self) -> bool {
390         self.0.is_empty()
391     }
392 
393     /// Steals half the tasks from self and place them into `dst`.
steal_into( &self, dst: &mut Local<T>, dst_stats: &mut Stats, ) -> Option<task::Notified<T>>394     pub(crate) fn steal_into(
395         &self,
396         dst: &mut Local<T>,
397         dst_stats: &mut Stats,
398     ) -> Option<task::Notified<T>> {
399         // Safety: the caller is the only thread that mutates `dst.tail` and
400         // holds a mutable reference.
401         let dst_tail = unsafe { dst.inner.tail.unsync_load() };
402 
403         // To the caller, `dst` may **look** empty but still have values
404         // contained in the buffer. If another thread is concurrently stealing
405         // from `dst` there may not be enough capacity to steal.
406         let (steal, _) = unpack(dst.inner.head.load(Acquire));
407 
408         if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
409             // we *could* try to steal less here, but for simplicity, we're just
410             // going to abort.
411             return None;
412         }
413 
414         // Steal the tasks into `dst`'s buffer. This does not yet expose the
415         // tasks in `dst`.
416         let mut n = self.steal_into2(dst, dst_tail);
417 
418         if n == 0 {
419             // No tasks were stolen
420             return None;
421         }
422 
423         dst_stats.incr_steal_count(n as u16);
424         dst_stats.incr_steal_operations();
425 
426         // We are returning a task here
427         n -= 1;
428 
429         let ret_pos = dst_tail.wrapping_add(n);
430         let ret_idx = ret_pos as usize & MASK;
431 
432         // safety: the value was written as part of `steal_into2` and not
433         // exposed to stealers, so no other thread can access it.
434         let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });
435 
436         if n == 0 {
437             // The `dst` queue is empty, but a single task was stolen
438             return Some(ret);
439         }
440 
441         // Make the stolen items available to consumers
442         dst.inner.tail.store(dst_tail.wrapping_add(n), Release);
443 
444         Some(ret)
445     }
446 
447     // Steal tasks from `self`, placing them into `dst`. Returns the number of
448     // tasks that were stolen.
steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort449     fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
450         let mut prev_packed = self.0.head.load(Acquire);
451         let mut next_packed;
452 
453         let n = loop {
454             let (src_head_steal, src_head_real) = unpack(prev_packed);
455             let src_tail = self.0.tail.load(Acquire);
456 
457             // If these two do not match, another thread is concurrently
458             // stealing from the queue.
459             if src_head_steal != src_head_real {
460                 return 0;
461             }
462 
463             // Number of available tasks to steal
464             let n = src_tail.wrapping_sub(src_head_real);
465             let n = n - n / 2;
466 
467             if n == 0 {
468                 // No tasks available to steal
469                 return 0;
470             }
471 
472             // Update the real head index to acquire the tasks.
473             let steal_to = src_head_real.wrapping_add(n);
474             assert_ne!(src_head_steal, steal_to);
475             next_packed = pack(src_head_steal, steal_to);
476 
477             // Claim all those tasks. This is done by incrementing the "real"
478             // head but not the steal. By doing this, no other thread is able to
479             // steal from this queue until the current thread completes.
480             let res = self
481                 .0
482                 .head
483                 .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);
484 
485             match res {
486                 Ok(_) => break n,
487                 Err(actual) => prev_packed = actual,
488             }
489         };
490 
491         assert!(
492             n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
493             "actual = {}",
494             n
495         );
496 
497         let (first, _) = unpack(next_packed);
498 
499         // Take all the tasks
500         for i in 0..n {
501             // Compute the positions
502             let src_pos = first.wrapping_add(i);
503             let dst_pos = dst_tail.wrapping_add(i);
504 
505             // Map to slots
506             let src_idx = src_pos as usize & MASK;
507             let dst_idx = dst_pos as usize & MASK;
508 
509             // Read the task
510             //
511             // safety: We acquired the task with the atomic exchange above.
512             let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });
513 
514             // Write the task to the new slot
515             //
516             // safety: `dst` queue is empty and we are the only producer to
517             // this queue.
518             dst.inner.buffer[dst_idx]
519                 .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
520         }
521 
522         let mut prev_packed = next_packed;
523 
524         // Update `src_head_steal` to match `src_head_real` signalling that the
525         // stealing routine is complete.
526         loop {
527             let head = unpack(prev_packed).1;
528             next_packed = pack(head, head);
529 
530             let res = self
531                 .0
532                 .head
533                 .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);
534 
535             match res {
536                 Ok(_) => return n,
537                 Err(actual) => {
538                     let (actual_steal, actual_real) = unpack(actual);
539 
540                     assert_ne!(actual_steal, actual_real);
541 
542                     prev_packed = actual;
543                 }
544             }
545         }
546     }
547 }
548 
549 cfg_metrics! {
550     impl<T> Steal<T> {
551         pub(crate) fn len(&self) -> usize {
552             self.0.len() as _
553         }
554     }
555 }
556 
557 impl<T> Clone for Steal<T> {
clone(&self) -> Steal<T>558     fn clone(&self) -> Steal<T> {
559         Steal(self.0.clone())
560     }
561 }
562 
563 impl<T> Drop for Local<T> {
drop(&mut self)564     fn drop(&mut self) {
565         if !std::thread::panicking() {
566             assert!(self.pop().is_none(), "queue not empty");
567         }
568     }
569 }
570 
571 impl<T> Inner<T> {
remaining_slots(&self) -> usize572     fn remaining_slots(&self) -> usize {
573         let (steal, _) = unpack(self.head.load(Acquire));
574         let tail = self.tail.load(Acquire);
575 
576         LOCAL_QUEUE_CAPACITY - (tail.wrapping_sub(steal) as usize)
577     }
578 
len(&self) -> UnsignedShort579     fn len(&self) -> UnsignedShort {
580         let (_, head) = unpack(self.head.load(Acquire));
581         let tail = self.tail.load(Acquire);
582 
583         tail.wrapping_sub(head)
584     }
585 
is_empty(&self) -> bool586     fn is_empty(&self) -> bool {
587         self.len() == 0
588     }
589 }
590 
591 /// Split the head value into the real head and the index a stealer is working
592 /// on.
unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort)593 fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
594     let real = n & UnsignedShort::MAX as UnsignedLong;
595     let steal = n >> (mem::size_of::<UnsignedShort>() * 8);
596 
597     (steal as UnsignedShort, real as UnsignedShort)
598 }
599 
600 /// Join the two head values
pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong601 fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
602     (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
603 }
604 
// Checks that the largest slot index (`LOCAL_QUEUE_CAPACITY - 1`) fits in a
// `u8`.
#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}
609