// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with two changes:
//   * no poisoning
//   * init function can fail

use std::{
    cell::{Cell, UnsafeCell},
    marker::PhantomData,
    panic::{RefUnwindSafe, UnwindSafe},
    sync::atomic::{AtomicBool, AtomicPtr, Ordering},
    thread::{self, Thread},
};

#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `queue` field is the core of the implementation. It encodes two
    // pieces of information:
    //
    // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
    // * Linked list of threads waiting for the current cell.
    //
    // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
    // allow waiters.
    queue: AtomicPtr<Waiter>,
    _marker: PhantomData<*mut Waiter>,
    value: UnsafeCell<Option<T>>,
}

// Why do we need `T: Send`?
// Thread A creates a `OnceCell` and shares it with
// scoped thread B, which fills the cell, which is
// then destroyed by A. That is, the destructor observes
// a sent value.
unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
unsafe impl<T: Send> Send for OnceCell<T> {}
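
// Illustrative sketch, not part of the original module: the scenario from the
// comment above, written out with `std::thread::scope` (using that API here is
// an assumption for the example). A scoped thread B fills the cell, and the
// owning thread A later drops the cell together with the value B produced,
// which is why `Sync` for `OnceCell<T>` also demands `T: Send`.
#[cfg(test)]
#[test]
fn sync_requires_send_sketch() {
    let cell: OnceCell<String> = OnceCell::new();
    std::thread::scope(|s| {
        // Thread B initializes the shared cell.
        s.spawn(|| {
            let _ = cell.initialize(|| Ok::<String, ()>(String::from("created on B")));
        });
    });
    // The value written by thread B is dropped here, on the owning thread A.
    drop(cell);
}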

impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}

impl<T> OnceCell<T> {
    pub(crate) const fn new() -> OnceCell<T> {
        OnceCell {
            queue: AtomicPtr::new(INCOMPLETE_PTR),
            _marker: PhantomData,
            value: UnsafeCell::new(None),
        }
    }

    pub(crate) const fn with_value(value: T) -> OnceCell<T> {
        OnceCell {
            queue: AtomicPtr::new(COMPLETE_PTR),
            _marker: PhantomData,
            value: UnsafeCell::new(Some(value)),
        }
    }

    /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst).
    #[inline]
    pub(crate) fn is_initialized(&self) -> bool {
        // An `Acquire` load is enough because that makes all the initialization
        // operations visible to us, and, this being a fast path, weaker
        // ordering helps with performance. This `Acquire` synchronizes with
        // `SeqCst` operations on the slow path.
        self.queue.load(Ordering::Acquire) == COMPLETE_PTR
    }

    /// Safety: synchronizes with store to value via SeqCst read from state,
    /// writes value only once because we never get to INCOMPLETE state after a
    /// successful write.
    #[cold]
    pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let mut f = Some(f);
        let mut res: Result<(), E> = Ok(());
        let slot: *mut Option<T> = self.value.get();
        initialize_or_wait(
            &self.queue,
            Some(&mut || {
                let f = unsafe { crate::unwrap_unchecked(f.take()) };
                match f() {
                    Ok(value) => {
                        unsafe { *slot = Some(value) };
                        true
                    }
                    Err(err) => {
                        res = Err(err);
                        false
                    }
                }
            }),
        );
        res
    }

    #[cold]
    pub(crate) fn wait(&self) {
        initialize_or_wait(&self.queue, None);
    }

    /// Get the reference to the underlying value, without checking if the cell
    /// is initialized.
    ///
    /// # Safety
    ///
    /// Caller must ensure that the cell is in initialized state, and that
    /// the contents are acquired by (synchronized to) this thread.
    pub(crate) unsafe fn get_unchecked(&self) -> &T {
        debug_assert!(self.is_initialized());
        let slot = &*self.value.get();
        crate::unwrap_unchecked(slot.as_ref())
    }

    /// Gets the mutable reference to the underlying value.
    /// Returns `None` if the cell is empty.
    pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
        // Safe b/c we have a unique access.
        unsafe { &mut *self.value.get() }.as_mut()
    }

    /// Consumes this `OnceCell`, returning the wrapped value.
    /// Returns `None` if the cell was empty.
    #[inline]
    pub(crate) fn into_inner(self) -> Option<T> {
        // Because `into_inner` takes `self` by value, the compiler statically
        // verifies that it is not currently borrowed.
        // So, it is safe to move out `Option<T>`.
        self.value.into_inner()
    }
}
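
// Illustrative usage sketch, not part of the original module: `initialize`
// takes a fallible closure. On `Err` the cell stays empty and a later call can
// retry; once a call succeeds, the stored value is never overwritten. The value
// `92` and the `&str` error type below are arbitrary choices for the example.
#[cfg(test)]
#[test]
fn fallible_initialize_sketch() {
    let cell: OnceCell<u32> = OnceCell::new();

    // A failed initialization leaves the cell uninitialized.
    assert!(cell.initialize(|| Err::<u32, &str>("boom")).is_err());
    assert!(!cell.is_initialized());

    // A later successful initialization stores the value exactly once.
    assert!(cell.initialize(|| Ok::<u32, &str>(92)).is_ok());
    assert!(cell.is_initialized());
    // Safety: the cell was initialized above, on this very thread.
    assert_eq!(unsafe { *cell.get_unchecked() }, 92);
}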

// Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
// the OnceCell structure.
const INCOMPLETE: usize = 0x0;
const RUNNING: usize = 0x1;
const COMPLETE: usize = 0x2;
const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter;
const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter;

// Mask to learn about the state. All other bits are the queue of waiters if
// this is in the RUNNING state.
const STATE_MASK: usize = 0x3;

/// Representation of a node in the linked list of waiters in the RUNNING state.
/// A waiter is stored on the stack of the waiting thread.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    thread: Cell<Option<Thread>>,
    signaled: AtomicBool,
    next: *mut Waiter,
}
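
// Illustrative sketch, not part of the original module: because `Waiter` is
// 4-byte aligned, a real waiter pointer always has its two low bits clear, so
// the cell state can be packed into those bits and recovered with `STATE_MASK`.
// The round trip below uses the `strict` helpers defined further down.
#[cfg(test)]
#[test]
fn tagged_pointer_roundtrip_sketch() {
    let node = Waiter {
        thread: Cell::new(None),
        signaled: AtomicBool::new(false),
        next: core::ptr::null_mut(),
    };
    let ptr = &node as *const Waiter as *mut Waiter;

    // Tag the pointer with the `RUNNING` state in its low bits.
    let tagged = strict::map_addr(ptr, |q| q | RUNNING);

    // Both the state and the original pointer can be recovered.
    assert_eq!(strict::addr(tagged) & STATE_MASK, RUNNING);
    assert_eq!(strict::map_addr(tagged, |q| q & !STATE_MASK), ptr);
}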

/// Drains and notifies the queue of waiters on drop.
struct Guard<'a> {
    queue: &'a AtomicPtr<Waiter>,
    new_queue: *mut Waiter,
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

        let state = strict::addr(queue) & STATE_MASK;
        assert_eq!(state, RUNNING);

        unsafe {
            let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK);
            while !waiter.is_null() {
                let next = (*waiter).next;
                let thread = (*waiter).thread.take().unwrap();
                (*waiter).signaled.store(true, Ordering::Release);
                waiter = next;
                thread.unpark();
            }
        }
    }
}

// Corresponds to `std::sync::Once::call_inner`.
//
// Originally copied from std, but since modified to remove poisoning and to
// support wait.
//
// Note: this is intentionally monomorphic
#[inline(never)]
fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) {
    let mut curr_queue = queue.load(Ordering::Acquire);

    loop {
        let curr_state = strict::addr(curr_queue) & STATE_MASK;
        match (curr_state, &mut init) {
            (COMPLETE, _) => return,
            (INCOMPLETE, Some(init)) => {
                let exchange = queue.compare_exchange(
                    curr_queue,
                    strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING),
                    Ordering::Acquire,
                    Ordering::Acquire,
                );
                if let Err(new_queue) = exchange {
                    curr_queue = new_queue;
                    continue;
                }
                let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR };
                if init() {
                    guard.new_queue = COMPLETE_PTR;
                }
                return;
            }
            (INCOMPLETE, None) | (RUNNING, _) => {
                wait(queue, curr_queue);
                curr_queue = queue.load(Ordering::Acquire);
            }
            _ => debug_assert!(false),
        }
    }
}

fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) {
    let curr_state = strict::addr(curr_queue) & STATE_MASK;
    loop {
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: strict::map_addr(curr_queue, |q| q & !STATE_MASK),
        };
        let me = &node as *const Waiter as *mut Waiter;

        let exchange = queue.compare_exchange(
            curr_queue,
            strict::map_addr(me, |q| q | curr_state),
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(new_queue) = exchange {
            if strict::addr(new_queue) & STATE_MASK != curr_state {
                return;
            }
            curr_queue = new_queue;
            continue;
        }

        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}

// Polyfill of strict provenance from https://crates.io/crates/sptr.
//
// Use free-standing function rather than a trait to keep things simple and
// avoid any potential conflicts with future stable std API.
mod strict {
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        // SAFETY: Pointer-to-integer transmutes are valid (if you are okay with losing the
        // provenance).
        unsafe { core::mem::transmute(ptr) }
    }

    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        //
        // In the meantime, this operation is defined to be "as if" it was
        // a wrapping_offset, so we can emulate it as such. This should properly
        // restore pointer provenance even under today's compiler.
        let self_addr = self::addr(ptr) as isize;
        let dest_addr = addr as isize;
        let offset = dest_addr.wrapping_sub(self_addr);

        // This is the canonical desugaring of this operation,
        // but `pointer::cast` was only stabilized in 1.38.
        // self.cast::<u8>().wrapping_offset(offset).cast::<T>()
        (ptr as *mut u8).wrapping_offset(offset) as *mut T
    }

    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T
    where
        T: Sized,
    {
        self::with_addr(ptr, f(addr(ptr)))
    }
}
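
// Illustrative sketch, not part of the original module: `with_addr` is emulated
// via a wrapping byte offset, so the returned pointer keeps the provenance of
// the original. Shifting the address by one element agrees with ordinary
// wrapping pointer arithmetic.
#[cfg(test)]
#[test]
fn with_addr_matches_wrapping_offset_sketch() {
    let xs = [1u32, 2, 3, 4];
    let base = xs.as_ptr() as *mut u32;

    // Move the address forward by one `u32`.
    let shifted = strict::with_addr(base, strict::addr(base) + core::mem::size_of::<u32>());
    assert_eq!(shifted, base.wrapping_add(1));

    // The shifted pointer is still derived from `xs`, so it is valid to read.
    assert_eq!(unsafe { *shifted }, 2);
}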

// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // we can subvert poisoning, however
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // once any success happens, we stop propagating the poison
        O.init(|| {});
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // poison the once
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // make sure someone's waiting inside the once via a force
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // put another waiter on the once
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}