//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::atomic::{AtomicU16, AtomicU32};
use crate::loom::sync::Arc;
use crate::runtime::stats::WorkerStatsBatcher;
use crate::runtime::task::{self, Inject};

use std::mem::MaybeUninit;
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

/// Producer handle. May only be used from a single thread.
pub(super) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(super) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(super) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `u16` values. The `u16` in the LSB is the "real" head of
    /// the queue. The `u16` in the MSB is set by a stealer in the process of
    /// stealing values. It represents the first value being stolen in the
    /// batch. `u16` is used in order to distinguish between `head == tail`
    /// and `head == tail - capacity`.
    ///
    /// When both `u16` values are the same, there is no active stealer.
    ///
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicU32,

    /// Only updated by producer thread but read by many threads.
    tail: AtomicU16,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;

// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}
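
// An additional sanity check (not from the original tests): `make_fixed_size`
// only changes the buffer's type (boxed slice to boxed array of the exact
// capacity) and never its length. The `u8` element type here is arbitrary.
#[test]
fn test_make_fixed_size_len() {
    let buffer = vec![0u8; LOCAL_QUEUE_CAPACITY].into_boxed_slice();
    let fixed = make_fixed_size(buffer);
    assert_eq!(fixed.len(), LOCAL_QUEUE_CAPACITY);
}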

/// Create a new local run-queue
pub(super) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicU32::new(0),
        tail: AtomicU16::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}
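
// A minimal sketch (not from the original tests) of how the two handles behave
// for a freshly built queue: both report empty and `pop` returns `None`. The
// unit type stands in for the scheduler type parameter since no tasks are
// created. Gated off under loom because the loom primitives must run inside a
// loom model.
#[cfg(not(loom))]
#[test]
fn test_new_queue_is_empty() {
    let (steal, mut local) = local::<()>();

    assert!(steal.is_empty());
    assert!(!local.is_stealable());
    assert!(!local.has_tasks());
    assert!(local.pop().is_none());
}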

impl<T> Local<T> {
    /// Returns true if the queue has entries that can be stolen.
    pub(super) fn is_stealable(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Returns true if there are entries in the queue.
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this.
    pub(super) fn has_tasks(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Pushes a task to the back of the local queue, skipping the LIFO slot.
    pub(super) fn push_back(&mut self, mut task: task::Notified<T>, inject: &Inject<T>) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as u16 {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                inject.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, inject) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the above `if`
            // condition ensures we don't touch a cell if there is a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow(
        &mut self,
        task: task::Notified<T>,
        head: u16,
        tail: u16,
        inject: &Inject<T>,
    ) -> Result<(), task::Notified<T>> {
        /// The number of elements we take from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: u16 = (LOCAL_QUEUE_CAPACITY / 2) as u16;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {}; head = {}",
            tail,
            head
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(NUM_TASKS_TAKEN),
                    head.wrapping_add(NUM_TASKS_TAKEN),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: u32,
            i: u32,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == u32::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &*self.inner.buffer,
            head: head as u32,
            i: 0,
        };
        inject.push_batch(batch_iter.chain(std::iter::once(task)));

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(super) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    pub(super) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals half the tasks from `self` and places them into `dst`.
    pub(super) fn steal_into(
        &self,
        dst: &mut Local<T>,
        stats: &mut WorkerStatsBatcher,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as u16 / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);
        stats.incr_steal_count(n);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: u16) -> u16 {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
            // Take half of them, rounded up, so an odd count still yields at
            // least one task.
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(n <= LOCAL_QUEUE_CAPACITY as u16 / 2, "actual = {}", n);

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real` signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

impl<T> Inner<T> {
    fn is_empty(&self) -> bool {
        let (_, head) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        head == tail
    }
}

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: u32) -> (u16, u16) {
    let real = n & u16::MAX as u32;
    let steal = n >> 16;

    (steal as u16, real as u16)
}

/// Join the two head values
fn pack(steal: u16, real: u16) -> u32 {
    (real as u32) | ((steal as u32) << 16)
}
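
// An additional sanity check (not from the original tests): `pack` and
// `unpack` round-trip any pair of `u16` values, since the two halves occupy
// disjoint bits of the `u32`. A few arbitrary pairs are spot-checked here.
#[test]
fn test_pack_unpack_roundtrip() {
    for &(steal, real) in &[(0u16, 0u16), (1, 2), (u16::MAX, 0), (0, u16::MAX), (12_345, 54_321)] {
        assert_eq!(unpack(pack(steal, real)), (steal, real));
    }
}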

#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}
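
// An additional sanity check (not from the original tests) of the ring-buffer
// indexing: mapping a position to a slot with `& MASK` relies on
// `LOCAL_QUEUE_CAPACITY` being a power of two, so a position that wraps past
// the capacity lands on the same slot.
#[test]
fn test_mask_wraps_positions() {
    assert!(LOCAL_QUEUE_CAPACITY.is_power_of_two());

    let pos: u16 = 3;
    let wrapped = pos.wrapping_add(LOCAL_QUEUE_CAPACITY as u16);
    assert_eq!(pos as usize & MASK, wrapped as usize & MASK);
}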