//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{AtomicU16, AtomicU32};
use crate::loom::sync::Arc;
use crate::runtime::stats::WorkerStatsBatcher;
use crate::runtime::task::{self, Inject};

use std::mem::MaybeUninit;
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

/// Producer handle. May only be used from a single thread.
pub(super) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(super) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(super) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `u16` values. The least-significant `u16` is the "real"
    /// head of the queue. The most-significant `u16` is set by a stealer in
    /// the process of stealing values. It represents the first value being
    /// stolen in the batch. The `u16` indices are used in order to distinguish
    /// between `head == tail` and `head == tail - capacity`.
    ///
    /// When both `u16` values are the same, there is no active stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicU32,

    /// Only updated by producer thread but read by many threads.
    tail: AtomicU16,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;

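// A minimal illustrative test (hypothetical name, assuming the non-loom
// capacity of 256): positions are mapped to buffer slots by masking with
// `MASK`, so slot indices stay in range even after the `u16` head/tail
// counters wrap.
#[cfg(not(loom))]
#[test]
fn mask_maps_positions_to_slots() {
    // Position 257 wraps around the 256-slot buffer to slot 1.
    assert_eq!(257 & MASK, 1);
    // The largest `u16` position maps to the last slot.
    assert_eq!(u16::MAX as usize & MASK, LOCAL_QUEUE_CAPACITY - 1);
}
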
// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}

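// A minimal illustrative test (hypothetical name; `usize` elements are used
// only for the example): `make_fixed_size` keeps the elements intact while
// reinterpreting the boxed slice as a boxed array.
#[test]
fn make_fixed_size_keeps_elements() {
    // Build a slice of exactly LOCAL_QUEUE_CAPACITY elements and convert it.
    let boxed: Box<[usize]> = (0..LOCAL_QUEUE_CAPACITY).collect();
    let fixed = make_fixed_size(boxed);

    assert_eq!(fixed[0], 0);
    assert_eq!(fixed[LOCAL_QUEUE_CAPACITY - 1], LOCAL_QUEUE_CAPACITY - 1);
}
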
/// Create a new local run-queue
pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicU32::new(0),
        tail: AtomicU16::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}

impl<T> Local<T> {
    /// Returns true if the queue has entries that can be stolen.
    pub(super) fn is_stealable(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Returns true if there are any entries in the queue.
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this.
    pub(super) fn has_tasks(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Pushes a task to the back of the local queue, skipping the LIFO slot.
    pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as u16 {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // A steal is in progress; it will free up capacity, so only
                // push this task onto the inject queue.
                inject.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, inject) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the above `if`
            // condition ensures we don't touch a cell if there is a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow(
        &mut self,
        task: task::Notified<T>,
        head: u16,
        tail: u16,
        inject: &Inject<T>,
    ) -> Result<(), task::Notified<T>> {
        /// How many elements are we taking from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: u16 = (LOCAL_QUEUE_CAPACITY / 2) as u16;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {}; head = {}",
            tail,
            head
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(NUM_TASKS_TAKEN),
                    head.wrapping_add(NUM_TASKS_TAKEN),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: u32,
            i: u32,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == u32::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &*self.inner.buffer,
            head: head as u32,
            i: 0,
        };
        inject.push_batch(batch_iter.chain(std::iter::once(task)));

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(super) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

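// A minimal illustrative test (hypothetical name and values): `push_back`
// checks for capacity with `tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY
// as u16`, and the wrapping subtraction keeps that check correct even after
// the `u16` positions wrap past zero.
#[cfg(not(loom))]
#[test]
fn capacity_check_survives_index_wrap() {
    // `steal` sits just below the wrap point and `tail` has already wrapped;
    // the queue holds 10 elements, well under the capacity of 256.
    let steal: u16 = u16::MAX - 5;
    let tail: u16 = steal.wrapping_add(10);

    assert_eq!(tail.wrapping_sub(steal), 10);
    assert!(tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as u16);
}
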
impl<T> Steal<T> {
    pub(super) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals half the tasks from `self` and places them into `dst`.
    pub(super) fn steal_into(
        &self,
        dst: &mut Local<T>,
        stats: &mut WorkerStatsBatcher,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as u16 / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);
        stats.incr_steal_count(n);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: u16) -> u16 {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(n <= LOCAL_QUEUE_CAPACITY as u16 / 2, "actual = {}", n);

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real` signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

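// A minimal illustrative test (hypothetical name): `steal_into2` claims
// `n - n / 2` of the available tasks, i.e. half of them rounded up, so the
// computed batch is never zero when tasks are available.
#[test]
fn steal_half_rounds_up() {
    for available in 1u16..=8 {
        let stolen = available - available / 2;

        assert_eq!(stolen, (available + 1) / 2);
        assert!(stolen >= 1);
        assert!(stolen <= available);
    }
}
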
impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

impl<T> Inner<T> {
    fn is_empty(&self) -> bool {
        let (_, head) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        head == tail
    }
}

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: u32) -> (u16, u16) {
    let real = n & u16::MAX as u32;
    let steal = n >> 16;

    (steal as u16, real as u16)
}

/// Join the two head values
fn pack(steal: u16, real: u16) -> u32 {
    (real as u32) | ((steal as u32) << 16)
}

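// A minimal illustrative test (hypothetical name): round-trips the `pack` /
// `unpack` encoding documented on `Inner::head`, with the "steal" index in
// the high 16 bits and the "real" head in the low 16 bits.
#[test]
fn test_pack_unpack_roundtrip() {
    let packed = pack(0x1234, 0xabcd);

    assert_eq!(packed, 0x1234_abcd);
    assert_eq!(unpack(packed), (0x1234, 0xabcd));
}
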
#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}