//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::task::{self, Inject};
use crate::runtime::MetricsBatch;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);
pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The `UnsignedShort` in the LSBs is
    /// the "real" head of the queue. The `UnsignedShort` in the MSBs is set by
    /// a stealer in the process of stealing values. It represents the first
    /// value being stolen in the batch. The `UnsignedShort` indices are
    /// intentionally wider than strictly required for buffer indexing in order
    /// to provide ABA mitigation and make it possible to distinguish between
    /// full and empty buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
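    ///
    /// For illustration, on the 64-bit path (`UnsignedShort = u32`) the packed
    /// value is `(steal << 32) | real`; see `pack` and `unpack` below.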
    head: AtomicUnsignedLong,

    /// Only updated by producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;
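// Because the capacity is a power of two, `pos & MASK` is equivalent to
// `pos % LOCAL_QUEUE_CAPACITY` and maps a monotonically increasing position to
// a slot index in the buffer.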

// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}

/// Create a new local run-queue
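///
/// The `Local` half is intended to stay on a single worker thread, while the
/// `Steal` half may be cloned and handed to other workers that steal from it.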
pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicUnsignedLong::new(0),
        tail: AtomicUnsignedShort::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}

impl<T> Local<T> {
    /// Returns true if the queue has entries that can be stolen.
    pub(crate) fn is_stealable(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Returns true if there are any entries in the queue.
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this.
    pub(crate) fn has_tasks(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Pushes a task to the back of the local queue, skipping the LIFO slot.
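    ///
    /// If the local buffer is full, the task overflows into the `inject`
    /// queue, possibly along with half of the buffered tasks (see
    /// `push_overflow`).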
    pub(crate) fn push_back(
        &mut self,
        mut task: task::Notified<T>,
        inject: &Inject<T>,
        metrics: &mut MetricsBatch,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

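            // Capacity is measured against the steal head rather than the real
            // head: slots between `steal` and `real` may still be read by an
            // in-progress stealer and must not be overwritten yet.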
            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                inject.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, inject, metrics) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the above `if`
            // condition ensures we don't touch a cell if there is a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        inject: &Inject<T>,
        metrics: &mut MetricsBatch,
    ) -> Result<(), task::Notified<T>> {
        /// How many elements are we taking from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue as we are also inserting the `task` argument.
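        ///
        /// For example, with `LOCAL_QUEUE_CAPACITY == 256`, 128 tasks are
        /// moved out of the buffer and 129 tasks (including `task`) end up in
        /// the inject queue.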
        const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {}; head = {}",
            tail,
            head
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(NUM_TASKS_TAKEN),
                    head.wrapping_add(NUM_TASKS_TAKEN),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: UnsignedLong,
            i: UnsignedLong,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            head: head as UnsignedLong,
            i: 0,
        };
        inject.push_batch(batch_iter.chain(std::iter::once(task)));

        // Add 1 to factor in the task currently being scheduled.
        metrics.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    pub(crate) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals half the tasks from `self` and places them into `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_metrics: &mut MetricsBatch,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

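        // Only proceed when at least half of `dst`'s buffer is free (measured
        // from the steal head): `steal_into2` may move up to half of a full
        // source queue into `dst`.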
        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        dst_metrics.incr_steal_count(n as u16);
        dst_metrics.incr_steal_operations();

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
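            // Take half of them, rounding up: `n - n / 2 == ceil(n / 2)`.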
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(
            n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
            "actual = {}",
            n
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real` signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

cfg_metrics! {
    impl<T> Steal<T> {
        pub(crate) fn len(&self) -> usize {
            self.0.len() as _
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

impl<T> Inner<T> {
    fn len(&self) -> UnsignedShort {
        let (_, head) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        tail.wrapping_sub(head)
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
    let real = n & UnsignedShort::MAX as UnsignedLong;
    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

    (steal as UnsignedShort, real as UnsignedShort)
}

/// Join the two head values
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}
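
// A minimal sanity check of the packing scheme: packing a (steal, real) pair
// and unpacking it again returns the original indices.
#[test]
fn test_pack_unpack_roundtrip() {
    assert_eq!(unpack(pack(3, 7)), (3, 7));
}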

#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}
534