//! Run-queue structures to support a work-stealing scheduler

use crate::loom::cell::UnsafeCell;
use crate::loom::sync::Arc;
use crate::runtime::task::{self, Inject};
use crate::runtime::MetricsBatch;

use std::mem::{self, MaybeUninit};
use std::ptr;
use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release};

// Use wider integers when possible to increase ABA resilience.
//
// See issue #5041: <https://github.com/tokio-rs/tokio/issues/5041>.
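//
// Wider indices wrap far less often, so a packed `head` value is much less
// likely to be reused while a CAS loop still holds a stale copy of it.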
cfg_has_atomic_u64! {
    type UnsignedShort = u32;
    type UnsignedLong = u64;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU32;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU64;
}
cfg_not_has_atomic_u64! {
    type UnsignedShort = u16;
    type UnsignedLong = u32;
    type AtomicUnsignedShort = crate::loom::sync::atomic::AtomicU16;
    type AtomicUnsignedLong = crate::loom::sync::atomic::AtomicU32;
}

/// Producer handle. May only be used from a single thread.
pub(crate) struct Local<T: 'static> {
    inner: Arc<Inner<T>>,
}

/// Consumer handle. May be used from many threads.
pub(crate) struct Steal<T: 'static>(Arc<Inner<T>>);

pub(crate) struct Inner<T: 'static> {
    /// Concurrently updated by many threads.
    ///
    /// Contains two `UnsignedShort` values. The `UnsignedShort` in the LSBs is
    /// the "real" head of the queue. The `UnsignedShort` in the MSBs is set by
    /// a stealer in the process of stealing values; it represents the first
    /// value being stolen in the batch. The indices are intentionally wider
    /// than strictly required for buffer indexing in order to provide ABA
    /// mitigation and make it possible to distinguish between full and empty
    /// buffers.
    ///
    /// When both `UnsignedShort` values are the same, there is no active
    /// stealer.
    ///
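    /// For example, in the 64-bit layout (`UnsignedShort = u32`) a head value
    /// of `(3 << 32) | 5` unpacks to `steal == 3` and `real == 5`: a stealer
    /// began its batch at position 3 while the real head has advanced to 5.
    ///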
    /// Tracking an in-progress stealer prevents a wrapping scenario.
    head: AtomicUnsignedLong,

    /// Only updated by producer thread but read by many threads.
    tail: AtomicUnsignedShort,

    /// Elements
    buffer: Box<[UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY]>,
}

unsafe impl<T> Send for Inner<T> {}
unsafe impl<T> Sync for Inner<T> {}

#[cfg(not(loom))]
const LOCAL_QUEUE_CAPACITY: usize = 256;

// Shrink the size of the local queue when using loom. This shouldn't impact
// logic, but allows loom to test more edge cases in a reasonable amount of
// time.
#[cfg(loom)]
const LOCAL_QUEUE_CAPACITY: usize = 4;

const MASK: usize = LOCAL_QUEUE_CAPACITY - 1;
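// Because `LOCAL_QUEUE_CAPACITY` is a power of two, `pos as usize & MASK` is
// equivalent to `pos as usize % LOCAL_QUEUE_CAPACITY`; e.g. with a capacity of
// 256, position 260 maps to slot 260 & 0xFF == 4.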

// Constructing the fixed size array directly is very awkward. The only way to
// do it is to repeat `UnsafeCell::new(MaybeUninit::uninit())` 256 times, as
// the contents are not Copy. The trick with defining a const doesn't work for
// generic types.
fn make_fixed_size<T>(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> {
    assert_eq!(buffer.len(), LOCAL_QUEUE_CAPACITY);

    // safety: We check that the length is correct.
    unsafe { Box::from_raw(Box::into_raw(buffer).cast()) }
}

/// Create a new local run-queue
pub(crate) fn local<T: 'static>() -> (Steal<T>, Local<T>) {
    let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY);

    for _ in 0..LOCAL_QUEUE_CAPACITY {
        buffer.push(UnsafeCell::new(MaybeUninit::uninit()));
    }

    let inner = Arc::new(Inner {
        head: AtomicUnsignedLong::new(0),
        tail: AtomicUnsignedShort::new(0),
        buffer: make_fixed_size(buffer.into_boxed_slice()),
    });

    let local = Local {
        inner: inner.clone(),
    };

    let remote = Steal(inner);

    (remote, local)
}

impl<T> Local<T> {
    /// Returns true if the queue has entries that can be stolen.
    pub(crate) fn is_stealable(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Returns true if there are entries in the queue.
    ///
    /// Kept separate from `is_stealable` so that refactors of `is_stealable`
    /// to "protect" some tasks from stealing won't affect this method.
    pub(crate) fn has_tasks(&self) -> bool {
        !self.inner.is_empty()
    }

    /// Pushes a task to the back of the local queue, skipping the LIFO slot.
    pub(crate) fn push_back(
        &mut self,
        mut task: task::Notified<T>,
        inject: &Inject<T>,
        metrics: &mut MetricsBatch,
    ) {
        let tail = loop {
            let head = self.inner.head.load(Acquire);
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

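            // Capacity is measured against the *steal* index rather than the
            // real head: slots between `steal` and `real` may still be read by
            // an in-progress stealer and must not be overwritten yet.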
            if tail.wrapping_sub(steal) < LOCAL_QUEUE_CAPACITY as UnsignedShort {
                // There is capacity for the task
                break tail;
            } else if steal != real {
                // Concurrently stealing, this will free up capacity, so only
                // push the task onto the inject queue
                inject.push(task);
                return;
            } else {
                // Push the current task and half of the queue into the
                // inject queue.
                match self.push_overflow(task, real, tail, inject, metrics) {
                    Ok(_) => return,
                    // Lost the race, try again
                    Err(v) => {
                        task = v;
                    }
                }
            }
        };

        // Map the position to a slot index.
        let idx = tail as usize & MASK;

        self.inner.buffer[idx].with_mut(|ptr| {
            // Write the task to the slot
            //
            // Safety: There is only one producer and the above `if`
            // condition ensures we don't touch a cell if there is a
            // value, thus no consumer.
            unsafe {
                ptr::write((*ptr).as_mut_ptr(), task);
            }
        });

        // Make the task available. Synchronizes with a load in
        // `steal_into2`.
        self.inner.tail.store(tail.wrapping_add(1), Release);
    }

    /// Moves a batch of tasks into the inject queue.
    ///
    /// This will temporarily make some of the tasks unavailable to stealers.
    /// Once `push_overflow` is done, a notification is sent out, so if other
    /// workers "missed" some of the tasks during a steal, they will get
    /// another opportunity.
    #[inline(never)]
    fn push_overflow(
        &mut self,
        task: task::Notified<T>,
        head: UnsignedShort,
        tail: UnsignedShort,
        inject: &Inject<T>,
        metrics: &mut MetricsBatch,
    ) -> Result<(), task::Notified<T>> {
        /// How many elements are we taking from the local queue.
        ///
        /// This is one less than the number of tasks pushed to the inject
        /// queue as we are also inserting the `task` argument.
        const NUM_TASKS_TAKEN: UnsignedShort = (LOCAL_QUEUE_CAPACITY / 2) as UnsignedShort;

        assert_eq!(
            tail.wrapping_sub(head) as usize,
            LOCAL_QUEUE_CAPACITY,
            "queue is not full; tail = {}; head = {}",
            tail,
            head
        );

        let prev = pack(head, head);

        // Claim a bunch of tasks
        //
        // We are claiming the tasks **before** reading them out of the buffer.
        // This is safe because only the **current** thread is able to push new
        // tasks.
        //
        // There isn't really any need for memory ordering... Relaxed would
        // work. This is because all tasks are pushed into the queue from the
        // current thread (or memory has been acquired if the local queue handle
        // moved).
        if self
            .inner
            .head
            .compare_exchange(
                prev,
                pack(
                    head.wrapping_add(NUM_TASKS_TAKEN),
                    head.wrapping_add(NUM_TASKS_TAKEN),
                ),
                Release,
                Relaxed,
            )
            .is_err()
        {
            // We failed to claim the tasks, losing the race. Return out of
            // this function and try the full `push` routine again. The queue
            // may not be full anymore.
            return Err(task);
        }

        /// An iterator that takes elements out of the run queue.
        struct BatchTaskIter<'a, T: 'static> {
            buffer: &'a [UnsafeCell<MaybeUninit<task::Notified<T>>>; LOCAL_QUEUE_CAPACITY],
            head: UnsignedLong,
            i: UnsignedLong,
        }
        impl<'a, T: 'static> Iterator for BatchTaskIter<'a, T> {
            type Item = task::Notified<T>;

            #[inline]
            fn next(&mut self) -> Option<task::Notified<T>> {
                if self.i == UnsignedLong::from(NUM_TASKS_TAKEN) {
                    None
                } else {
                    let i_idx = self.i.wrapping_add(self.head) as usize & MASK;
                    let slot = &self.buffer[i_idx];

                    // safety: Our CAS from before has assumed exclusive ownership
                    // of the task pointers in this range.
                    let task = slot.with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

                    self.i += 1;
                    Some(task)
                }
            }
        }

        // safety: The CAS above ensures that no consumer will look at these
        // values again, and we are the only producer.
        let batch_iter = BatchTaskIter {
            buffer: &self.inner.buffer,
            head: head as UnsignedLong,
            i: 0,
        };
        inject.push_batch(batch_iter.chain(std::iter::once(task)));

        // Record that this queue overflowed into the inject queue.
        metrics.incr_overflow_count();

        Ok(())
    }

    /// Pops a task from the local queue.
    pub(crate) fn pop(&mut self) -> Option<task::Notified<T>> {
        let mut head = self.inner.head.load(Acquire);

        let idx = loop {
            let (steal, real) = unpack(head);

            // safety: this is the **only** thread that updates this cell.
            let tail = unsafe { self.inner.tail.unsync_load() };

            if real == tail {
                // queue is empty
                return None;
            }

            let next_real = real.wrapping_add(1);

            // If `steal == real` there are no concurrent stealers. Both `steal`
            // and `real` are updated.
            let next = if steal == real {
                pack(next_real, next_real)
            } else {
                assert_ne!(steal, next_real);
                pack(steal, next_real)
            };

            // Attempt to claim a task.
            let res = self
                .inner
                .head
                .compare_exchange(head, next, AcqRel, Acquire);

            match res {
                Ok(_) => break real as usize & MASK,
                Err(actual) => head = actual,
            }
        };

        Some(self.inner.buffer[idx].with(|ptr| unsafe { ptr::read(ptr).assume_init() }))
    }
}

impl<T> Steal<T> {
    pub(crate) fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Steals half the tasks from self and places them into `dst`.
    pub(crate) fn steal_into(
        &self,
        dst: &mut Local<T>,
        dst_metrics: &mut MetricsBatch,
    ) -> Option<task::Notified<T>> {
        // Safety: the caller is the only thread that mutates `dst.tail` and
        // holds a mutable reference.
        let dst_tail = unsafe { dst.inner.tail.unsync_load() };

        // To the caller, `dst` may **look** empty but still have values
        // contained in the buffer. If another thread is concurrently stealing
        // from `dst` there may not be enough capacity to steal.
        let (steal, _) = unpack(dst.inner.head.load(Acquire));

        if dst_tail.wrapping_sub(steal) > LOCAL_QUEUE_CAPACITY as UnsignedShort / 2 {
            // we *could* try to steal less here, but for simplicity, we're just
            // going to abort.
            return None;
        }

        // Steal the tasks into `dst`'s buffer. This does not yet expose the
        // tasks in `dst`.
        let mut n = self.steal_into2(dst, dst_tail);

        if n == 0 {
            // No tasks were stolen
            return None;
        }

        dst_metrics.incr_steal_count(n as u16);
        dst_metrics.incr_steal_operations();

        // We are returning a task here
        n -= 1;

        let ret_pos = dst_tail.wrapping_add(n);
        let ret_idx = ret_pos as usize & MASK;

        // safety: the value was written as part of `steal_into2` and not
        // exposed to stealers, so no other thread can access it.
        let ret = dst.inner.buffer[ret_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

        if n == 0 {
            // The `dst` queue is empty, but a single task was stolen
            return Some(ret);
        }

        // Make the stolen items available to consumers
        dst.inner.tail.store(dst_tail.wrapping_add(n), Release);

        Some(ret)
    }

    // Steal tasks from `self`, placing them into `dst`. Returns the number of
    // tasks that were stolen.
    fn steal_into2(&self, dst: &mut Local<T>, dst_tail: UnsignedShort) -> UnsignedShort {
        let mut prev_packed = self.0.head.load(Acquire);
        let mut next_packed;

        let n = loop {
            let (src_head_steal, src_head_real) = unpack(prev_packed);
            let src_tail = self.0.tail.load(Acquire);

            // If these two do not match, another thread is concurrently
            // stealing from the queue.
            if src_head_steal != src_head_real {
                return 0;
            }

            // Number of available tasks to steal
            let n = src_tail.wrapping_sub(src_head_real);
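            // Steal half of them, rounding up: `n - n / 2` is `ceil(n / 2)`, so
            // the larger half is taken when the count is odd.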
            let n = n - n / 2;

            if n == 0 {
                // No tasks available to steal
                return 0;
            }

            // Update the real head index to acquire the tasks.
            let steal_to = src_head_real.wrapping_add(n);
            assert_ne!(src_head_steal, steal_to);
            next_packed = pack(src_head_steal, steal_to);

            // Claim all those tasks. This is done by incrementing the "real"
            // head but not the steal. By doing this, no other thread is able to
            // steal from this queue until the current thread completes.
            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => break n,
                Err(actual) => prev_packed = actual,
            }
        };

        assert!(
            n <= LOCAL_QUEUE_CAPACITY as UnsignedShort / 2,
            "actual = {}",
            n
        );

        let (first, _) = unpack(next_packed);

        // Take all the tasks
        for i in 0..n {
            // Compute the positions
            let src_pos = first.wrapping_add(i);
            let dst_pos = dst_tail.wrapping_add(i);

            // Map to slots
            let src_idx = src_pos as usize & MASK;
            let dst_idx = dst_pos as usize & MASK;

            // Read the task
            //
            // safety: We acquired the task with the atomic exchange above.
            let task = self.0.buffer[src_idx].with(|ptr| unsafe { ptr::read((*ptr).as_ptr()) });

            // Write the task to the new slot
            //
            // safety: `dst` queue is empty and we are the only producer to
            // this queue.
            dst.inner.buffer[dst_idx]
                .with_mut(|ptr| unsafe { ptr::write((*ptr).as_mut_ptr(), task) });
        }

        let mut prev_packed = next_packed;

        // Update `src_head_steal` to match `src_head_real`, signalling that the
        // stealing routine is complete.
        loop {
            let head = unpack(prev_packed).1;
            next_packed = pack(head, head);

            let res = self
                .0
                .head
                .compare_exchange(prev_packed, next_packed, AcqRel, Acquire);

            match res {
                Ok(_) => return n,
                Err(actual) => {
                    let (actual_steal, actual_real) = unpack(actual);

                    assert_ne!(actual_steal, actual_real);

                    prev_packed = actual;
                }
            }
        }
    }
}

cfg_metrics! {
    impl<T> Steal<T> {
        pub(crate) fn len(&self) -> usize {
            self.0.len() as _
        }
    }
}

impl<T> Clone for Steal<T> {
    fn clone(&self) -> Steal<T> {
        Steal(self.0.clone())
    }
}

impl<T> Drop for Local<T> {
    fn drop(&mut self) {
        if !std::thread::panicking() {
            assert!(self.pop().is_none(), "queue not empty");
        }
    }
}

impl<T> Inner<T> {
    fn len(&self) -> UnsignedShort {
        let (_, head) = unpack(self.head.load(Acquire));
        let tail = self.tail.load(Acquire);

        tail.wrapping_sub(head)
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

/// Split the head value into the real head and the index a stealer is working
/// on.
fn unpack(n: UnsignedLong) -> (UnsignedShort, UnsignedShort) {
    let real = n & UnsignedShort::MAX as UnsignedLong;
    let steal = n >> (mem::size_of::<UnsignedShort>() * 8);

    (steal as UnsignedShort, real as UnsignedShort)
}

/// Join the two head values
fn pack(steal: UnsignedShort, real: UnsignedShort) -> UnsignedLong {
    (real as UnsignedLong) | ((steal as UnsignedLong) << (mem::size_of::<UnsignedShort>() * 8))
}

#[test]
fn test_local_queue_capacity() {
    assert!(LOCAL_QUEUE_CAPACITY - 1 <= u8::MAX as usize);
}
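
// A minimal sketch of an additional round-trip check for `pack`/`unpack` (not
// part of the original test suite): packing a (steal, real) pair and unpacking
// it again should return the same pair, including at the `UnsignedShort`
// boundaries.
#[test]
fn test_pack_unpack_round_trip() {
    for &(steal, real) in &[
        (0, 0),
        (1, 2),
        (UnsignedShort::MAX, 0),
        (0, UnsignedShort::MAX),
        (UnsignedShort::MAX, UnsignedShort::MAX),
    ] {
        assert_eq!((steal, real), unpack(pack(steal, real)));
    }
}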