// There's a lot of scary concurrent code in this module, but it is copied from
// `std::sync::Once` with a few changes:
// * no poisoning
// * init function can fail
// * threads can wait for initialization without running an init function
5
6 use std::{
7 cell::{Cell, UnsafeCell},
8 marker::PhantomData,
9 panic::{RefUnwindSafe, UnwindSafe},
10 sync::atomic::{AtomicBool, AtomicPtr, Ordering},
11 thread::{self, Thread},
12 };
13
/// A cell that is initialized at most once and can be shared between threads.
///
/// The cell's state and its queue of waiting threads live in `queue`; the
/// payload lives in `value`.
#[derive(Debug)]
pub(crate) struct OnceCell<T> {
    // This `queue` field is the core of the implementation. It encodes two
    // pieces of information:
    //
    // * The current state of the cell (`INCOMPLETE`, `RUNNING`, `COMPLETE`)
    // * Linked list of threads waiting for the current cell.
    //
    // State is encoded in two low bits. Only `INCOMPLETE` and `RUNNING` states
    // allow waiters.
    queue: AtomicPtr<Waiter>,
    // Raw-pointer marker: makes the struct neither `Send` nor `Sync` by
    // default, so the explicit `unsafe impl`s decide those auto traits.
    _marker: PhantomData<*mut Waiter>,
    // The payload; `None` while the cell is empty.
    value: UnsafeCell<Option<T>>,
}
28
29 // Why do we need `T: Send`?
30 // Thread A creates a `OnceCell` and shares it with
31 // scoped thread B, which fills the cell, which is
32 // then destroyed by A. That is, destructor observes
33 // a sent value.
34 unsafe impl<T: Sync + Send> Sync for OnceCell<T> {}
35 unsafe impl<T: Send> Send for OnceCell<T> {}
36
37 impl<T: RefUnwindSafe + UnwindSafe> RefUnwindSafe for OnceCell<T> {}
38 impl<T: UnwindSafe> UnwindSafe for OnceCell<T> {}
39
40 impl<T> OnceCell<T> {
new() -> OnceCell<T>41 pub(crate) const fn new() -> OnceCell<T> {
42 OnceCell {
43 queue: AtomicPtr::new(INCOMPLETE_PTR),
44 _marker: PhantomData,
45 value: UnsafeCell::new(None),
46 }
47 }
48
with_value(value: T) -> OnceCell<T>49 pub(crate) const fn with_value(value: T) -> OnceCell<T> {
50 OnceCell {
51 queue: AtomicPtr::new(COMPLETE_PTR),
52 _marker: PhantomData,
53 value: UnsafeCell::new(Some(value)),
54 }
55 }
56
57 /// Safety: synchronizes with store to value via Release/(Acquire|SeqCst).
58 #[inline]
is_initialized(&self) -> bool59 pub(crate) fn is_initialized(&self) -> bool {
60 // An `Acquire` load is enough because that makes all the initialization
61 // operations visible to us, and, this being a fast path, weaker
62 // ordering helps with performance. This `Acquire` synchronizes with
63 // `SeqCst` operations on the slow path.
64 self.queue.load(Ordering::Acquire) == COMPLETE_PTR
65 }
66
67 /// Safety: synchronizes with store to value via SeqCst read from state,
68 /// writes value only once because we never get to INCOMPLETE state after a
69 /// successful write.
70 #[cold]
initialize<F, E>(&self, f: F) -> Result<(), E> where F: FnOnce() -> Result<T, E>,71 pub(crate) fn initialize<F, E>(&self, f: F) -> Result<(), E>
72 where
73 F: FnOnce() -> Result<T, E>,
74 {
75 let mut f = Some(f);
76 let mut res: Result<(), E> = Ok(());
77 let slot: *mut Option<T> = self.value.get();
78 initialize_or_wait(
79 &self.queue,
80 Some(&mut || {
81 let f = unsafe { crate::unwrap_unchecked(f.take()) };
82 match f() {
83 Ok(value) => {
84 unsafe { *slot = Some(value) };
85 true
86 }
87 Err(err) => {
88 res = Err(err);
89 false
90 }
91 }
92 }),
93 );
94 res
95 }
96
97 #[cold]
wait(&self)98 pub(crate) fn wait(&self) {
99 initialize_or_wait(&self.queue, None);
100 }
101
102 /// Get the reference to the underlying value, without checking if the cell
103 /// is initialized.
104 ///
105 /// # Safety
106 ///
107 /// Caller must ensure that the cell is in initialized state, and that
108 /// the contents are acquired by (synchronized to) this thread.
get_unchecked(&self) -> &T109 pub(crate) unsafe fn get_unchecked(&self) -> &T {
110 debug_assert!(self.is_initialized());
111 let slot = &*self.value.get();
112 crate::unwrap_unchecked(slot.as_ref())
113 }
114
115 /// Gets the mutable reference to the underlying value.
116 /// Returns `None` if the cell is empty.
get_mut(&mut self) -> Option<&mut T>117 pub(crate) fn get_mut(&mut self) -> Option<&mut T> {
118 // Safe b/c we have a unique access.
119 unsafe { &mut *self.value.get() }.as_mut()
120 }
121
122 /// Consumes this `OnceCell`, returning the wrapped value.
123 /// Returns `None` if the cell was empty.
124 #[inline]
into_inner(self) -> Option<T>125 pub(crate) fn into_inner(self) -> Option<T> {
126 // Because `into_inner` takes `self` by value, the compiler statically
127 // verifies that it is not currently borrowed.
128 // So, it is safe to move out `Option<T>`.
129 self.value.into_inner()
130 }
131 }
132
133 // Three states that a OnceCell can be in, encoded into the lower bits of `queue` in
134 // the OnceCell structure.
135 const INCOMPLETE: usize = 0x0;
136 const RUNNING: usize = 0x1;
137 const COMPLETE: usize = 0x2;
138 const INCOMPLETE_PTR: *mut Waiter = INCOMPLETE as *mut Waiter;
139 const COMPLETE_PTR: *mut Waiter = COMPLETE as *mut Waiter;
140
141 // Mask to learn about the state. All other bits are the queue of waiters if
142 // this is in the RUNNING state.
143 const STATE_MASK: usize = 0x3;
144
/// Representation of a node in the linked list of waiters.
///
/// A waiter is stored on the stack of the waiting thread. It must not be
/// touched after `signaled` has been set, because the waiting thread may
/// return, and so free the node, as soon as it observes the flag.
#[repr(align(4))] // Ensure the two lower bits are free to use as state bits.
struct Waiter {
    // Handle of the waiting thread; taken by the notifier before it sets
    // `signaled`.
    thread: Cell<Option<Thread>>,
    // Set (with `Release`) once the waiter has been woken.
    signaled: AtomicBool,
    // Next waiter in the list, or null. Never carries state bits.
    next: *mut Waiter,
}
153
154 /// Drains and notifies the queue of waiters on drop.
155 struct Guard<'a> {
156 queue: &'a AtomicPtr<Waiter>,
157 new_queue: *mut Waiter,
158 }
159
impl Drop for Guard<'_> {
    /// Publishes the final state and wakes every queued waiter.
    fn drop(&mut self) {
        // Swap in the final state and detach the waiter list in one step.
        // The `Release` half publishes the initializer's write to the value.
        // The `Acquire` half makes the waiters' nodes visible before we walk
        // the list; they were pushed with `Release` compare-exchanges in
        // `wait`.
        let queue = self.queue.swap(self.new_queue, Ordering::AcqRel);

        // Other threads only push waiters and keep the state bits unchanged,
        // so the state we replaced must still be `RUNNING`.
        let state = strict::addr(queue) & STATE_MASK;
        assert_eq!(state, RUNNING);

        unsafe {
            let mut waiter = strict::map_addr(queue, |q| q & !STATE_MASK);
            while !waiter.is_null() {
                // Read `next` and take the thread handle *before* setting
                // `signaled`. Once the flag is set, the waiting thread may
                // return, and its stack-allocated node becomes invalid.
                let next = (*waiter).next;
                let thread = (*waiter).thread.take().unwrap();
                (*waiter).signaled.store(true, Ordering::Release);
                waiter = next;
                // Unpark through our own handle, which does not depend on
                // the node still being alive.
                thread.unpark();
            }
        }
    }
}
179
180 // Corresponds to `std::sync::Once::call_inner`.
181 //
182 // Originally copied from std, but since modified to remove poisoning and to
183 // support wait.
184 //
185 // Note: this is intentionally monomorphic
186 #[inline(never)]
initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>)187 fn initialize_or_wait(queue: &AtomicPtr<Waiter>, mut init: Option<&mut dyn FnMut() -> bool>) {
188 let mut curr_queue = queue.load(Ordering::Acquire);
189
190 loop {
191 let curr_state = strict::addr(curr_queue) & STATE_MASK;
192 match (curr_state, &mut init) {
193 (COMPLETE, _) => return,
194 (INCOMPLETE, Some(init)) => {
195 let exchange = queue.compare_exchange(
196 curr_queue,
197 strict::map_addr(curr_queue, |q| (q & !STATE_MASK) | RUNNING),
198 Ordering::Acquire,
199 Ordering::Acquire,
200 );
201 if let Err(new_queue) = exchange {
202 curr_queue = new_queue;
203 continue;
204 }
205 let mut guard = Guard { queue, new_queue: INCOMPLETE_PTR };
206 if init() {
207 guard.new_queue = COMPLETE_PTR;
208 }
209 return;
210 }
211 (INCOMPLETE, None) | (RUNNING, _) => {
212 wait(&queue, curr_queue);
213 curr_queue = queue.load(Ordering::Acquire);
214 }
215 _ => debug_assert!(false),
216 }
217 }
218 }
219
/// Pushes the current thread onto the waiter list encoded in `queue`, then
/// parks until a `Guard` signals it.
///
/// `curr_queue` is the last observed value of `queue`. The caller
/// (`initialize_or_wait`) only passes values whose state bits are
/// `INCOMPLETE` or `RUNNING`. If the state changes before the node is
/// enqueued, this returns without parking, so the caller must re-check the
/// state.
fn wait(queue: &AtomicPtr<Waiter>, mut curr_queue: *mut Waiter) {
    let curr_state = strict::addr(curr_queue) & STATE_MASK;
    loop {
        // The node lives on this thread's stack. It stays valid until we
        // observe `signaled` below.
        let node = Waiter {
            thread: Cell::new(Some(thread::current())),
            signaled: AtomicBool::new(false),
            next: strict::map_addr(curr_queue, |q| q & !STATE_MASK),
        };
        let me = &node as *const Waiter as *mut Waiter;

        // Push `node` as the new list head, keeping the state bits. `Release`
        // publishes the node's fields to the thread that drains the list.
        let exchange = queue.compare_exchange(
            curr_queue,
            strict::map_addr(me, |q| q | curr_state),
            Ordering::Release,
            Ordering::Relaxed,
        );
        if let Err(new_queue) = exchange {
            // The state changed, e.g. initialization finished or failed.
            // Let the caller re-evaluate instead of queuing.
            if strict::addr(new_queue) & STATE_MASK != curr_state {
                return;
            }
            // Same state but a different list (e.g. another waiter enqueued
            // first): retry with a fresh node linked to the new head.
            curr_queue = new_queue;
            continue;
        }

        // `park` may wake spuriously, so loop on the flag. The `Acquire`
        // load pairs with the guard's `Release` store.
        while !node.signaled.load(Ordering::Acquire) {
            thread::park();
        }
        break;
    }
}
250
// Polyfill of strict provenance from https://crates.io/crates/sptr.
//
// Use free-standing functions rather than a trait to keep things simple and
// avoid any potential conflicts with a future stable std API.
mod strict {
    /// Returns the address of `ptr` as an integer, discarding its provenance.
    #[must_use]
    #[inline]
    pub(crate) fn addr<T>(ptr: *mut T) -> usize
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        // SAFETY: Pointer-to-integer transmutes are valid (if you are okay with losing the
        // provenance).
        unsafe { core::mem::transmute(ptr) }
    }

    /// Returns a pointer with address `addr` and the provenance of `ptr`.
    #[must_use]
    #[inline]
    pub(crate) fn with_addr<T>(ptr: *mut T, addr: usize) -> *mut T
    where
        T: Sized,
    {
        // FIXME(strict_provenance_magic): I am magic and should be a compiler intrinsic.
        //
        // In the meantime, this operation is defined to be "as if" it was
        // a wrapping_offset, so we can emulate it as such. This should properly
        // restore pointer provenance even under today's compiler.
        let self_addr = self::addr(ptr) as isize;
        let dest_addr = addr as isize;
        let offset = dest_addr.wrapping_sub(self_addr);

        // This is the canonical desugaring of this operation,
        // but `pointer::cast` was only stabilized in 1.38.
        // self.cast::<u8>().wrapping_offset(offset).cast::<T>()
        (ptr as *mut u8).wrapping_offset(offset) as *mut T
    }

    /// Replaces the address of `ptr` with `f(addr)`, keeping `ptr`'s
    /// provenance.
    #[must_use]
    #[inline]
    pub(crate) fn map_addr<T>(ptr: *mut T, f: impl FnOnce(usize) -> usize) -> *mut T
    where
        T: Sized,
    {
        self::with_addr(ptr, f(addr(ptr)))
    }
}
298
// These tests are snatched from std as well.
#[cfg(test)]
mod tests {
    use std::panic;
    use std::{sync::mpsc::channel, thread};

    use super::OnceCell;

    impl<T> OnceCell<T> {
        // Test-only convenience: initialize with an infallible closure. The
        // uninhabited `Void` error type means `initialize` can never return `Err`.
        fn init(&self, f: impl FnOnce() -> T) {
            enum Void {}
            let _ = self.initialize(|| Ok::<T, Void>(f()));
        }
    }

    #[test]
    fn smoke_once() {
        static O: OnceCell<()> = OnceCell::new();
        let mut a = 0;
        O.init(|| a += 1);
        assert_eq!(a, 1);
        O.init(|| a += 1);
        assert_eq!(a, 1);
    }

    #[test]
    fn stampede_once() {
        static O: OnceCell<()> = OnceCell::new();
        // Deliberately unsynchronized: the cell itself must provide the
        // happens-before edge that makes `RUN = true` visible to every thread.
        static mut RUN: bool = false;

        let (tx, rx) = channel();
        for _ in 0..10 {
            let tx = tx.clone();
            thread::spawn(move || {
                for _ in 0..4 {
                    thread::yield_now()
                }
                unsafe {
                    O.init(|| {
                        assert!(!RUN);
                        RUN = true;
                    });
                    assert!(RUN);
                }
                tx.send(()).unwrap();
            });
        }

        unsafe {
            O.init(|| {
                assert!(!RUN);
                RUN = true;
            });
            assert!(RUN);
        }

        for _ in 0..10 {
            rx.recv().unwrap();
        }
    }

    #[test]
    fn poison_bad() {
        static O: OnceCell<()> = OnceCell::new();

        // A panicking initializer leaves the cell uninitialized. This cell
        // has no poisoning, unlike std's `Once`.
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // ...so the next call runs its initializer normally.
        let mut called = false;
        O.init(|| {
            called = true;
        });
        assert!(called);

        // After a successful initialization, further calls do nothing.
        O.init(|| {});
    }

    #[test]
    fn wait_for_force_to_finish() {
        static O: OnceCell<()> = OnceCell::new();

        // A panicking initializer leaves the cell uninitialized.
        let t = panic::catch_unwind(|| {
            O.init(|| panic!());
        });
        assert!(t.is_err());

        // Start an initializer on t1 that blocks inside its closure until we
        // release it.
        let (tx1, rx1) = channel();
        let (tx2, rx2) = channel();
        let t1 = thread::spawn(move || {
            O.init(|| {
                tx1.send(()).unwrap();
                rx2.recv().unwrap();
            });
        });

        rx1.recv().unwrap();

        // t2 must never run its own closure. It either waits for t1 or
        // arrives after t1 has already finished.
        let t2 = thread::spawn(|| {
            let mut called = false;
            O.init(|| {
                called = true;
            });
            assert!(!called);
        });

        tx2.send(()).unwrap();

        assert!(t1.join().is_ok());
        assert!(t2.join().is_ok());
    }

    #[test]
    #[cfg(target_pointer_width = "64")]
    fn test_size() {
        use std::mem::size_of;

        // Guards against accidental growth of the cell's layout.
        assert_eq!(size_of::<OnceCell<u32>>(), 4 * size_of::<u32>());
    }
}
426