// Copyright 2017 Amanieu d'Antras
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! Per-object thread-local storage
//!
//! This library provides the `ThreadLocal` type which allows a separate copy of
//! an object to be used for each thread. This allows for per-object
//! thread-local storage, unlike the standard library's `thread_local!` macro
//! which only allows static thread-local storage.
//!
//! Per-thread objects are not destroyed when a thread exits. Instead, objects
//! are only destroyed when the `ThreadLocal` containing them is destroyed.
//!
//! You can also iterate over the thread-local values of all threads in a
//! `ThreadLocal` object using the `iter_mut` and `into_iter` methods. This can
//! only be done if you have mutable access to the `ThreadLocal` object, which
//! guarantees that you are the only thread currently accessing it.
//!
//! Note that since thread IDs are recycled when a thread exits, it is possible
//! for one thread to retrieve the object of another thread. Since this can only
//! occur after a thread has exited this does not lead to any race conditions.
//!
//! # Examples
//!
//! Basic usage of `ThreadLocal`:
//!
//! ```rust
//! use thread_local::ThreadLocal;
//! let tls: ThreadLocal<u32> = ThreadLocal::new();
//! assert_eq!(tls.get(), None);
//! assert_eq!(tls.get_or(|| 5), &5);
//! assert_eq!(tls.get(), Some(&5));
//! ```
//!
//! Combining thread-local values into a single result:
//!
//! ```rust
//! use thread_local::ThreadLocal;
//! use std::sync::Arc;
//! use std::cell::Cell;
//! use std::thread;
//!
//! let tls = Arc::new(ThreadLocal::new());
//!
//! // Create a bunch of threads to do stuff
//! for _ in 0..5 {
//!     let tls2 = tls.clone();
//!     thread::spawn(move || {
//!         // Increment a counter to count some event...
//!         let cell = tls2.get_or(|| Cell::new(0));
//!         cell.set(cell.get() + 1);
//!     }).join().unwrap();
//! }
//!
//! // Once all threads are done, collect the counter values and return the
//! // sum of all thread-local counter values.
//! let tls = Arc::try_unwrap(tls).unwrap();
//! let total = tls.into_iter().fold(0, |x, y| x + y.get());
//! assert_eq!(total, 5);
//! ```

#![warn(missing_docs)]
#![allow(clippy::mutex_atomic)]

mod cached;
mod thread_id;
mod unreachable;

#[allow(deprecated)]
pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};

use std::cell::UnsafeCell;
use std::fmt;
use std::iter::FusedIterator;
use std::mem;
use std::mem::MaybeUninit;
use std::panic::UnwindSafe;
use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use std::sync::Mutex;
use thread_id::Thread;
use unreachable::UncheckedResultExt;

// Use usize::BITS once it has stabilized and the MSRV has been bumped.
#[cfg(target_pointer_width = "16")]
const POINTER_WIDTH: u8 = 16;
#[cfg(target_pointer_width = "32")]
const POINTER_WIDTH: u8 = 32;
#[cfg(target_pointer_width = "64")]
const POINTER_WIDTH: u8 = 64;

/// The total number of buckets stored in each thread local.
const BUCKETS: usize = (POINTER_WIDTH + 1) as usize;
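// Bucket 0 holds a single entry and bucket n (n >= 1) holds 2^(n-1) entries,
// so the POINTER_WIDTH + 1 buckets together cover all 2^POINTER_WIDTH
// possible thread IDs.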

/// Thread-local variable wrapper
///
/// See the [module-level documentation](index.html) for more.
pub struct ThreadLocal<T: Send> {
    /// The buckets in the thread local. Bucket 0 contains one element; the
    /// nth bucket (for `n >= 1`) contains `2^(n-1)` elements. Each bucket is
    /// lazily allocated.
    buckets: [AtomicPtr<Entry<T>>; BUCKETS],

    /// The number of values in the thread local. This can be less than the
    /// real number of values, but is never more.
    values: AtomicUsize,

    /// Lock used to guard against concurrent modifications. This is taken when
    /// there is a possibility of allocating a new bucket, which only occurs
    /// when inserting values.
    lock: Mutex<()>,
}

struct Entry<T> {
    present: AtomicBool,
    value: UnsafeCell<MaybeUninit<T>>,
}

impl<T> Drop for Entry<T> {
    fn drop(&mut self) {
        unsafe {
            if *self.present.get_mut() {
                ptr::drop_in_place((*self.value.get()).as_mut_ptr());
            }
        }
    }
}

// ThreadLocal is always Sync, even if T isn't
unsafe impl<T: Send> Sync for ThreadLocal<T> {}

impl<T: Send> Default for ThreadLocal<T> {
    fn default() -> ThreadLocal<T> {
        ThreadLocal::new()
    }
}

impl<T: Send> Drop for ThreadLocal<T> {
    fn drop(&mut self) {
        let mut bucket_size = 1;
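        // Buckets 0 and 1 both have size 1, so the size only starts doubling
        // from bucket 1 onward (hence the `i != 0` check below).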

        // Free each non-null bucket
        for (i, bucket) in self.buckets.iter_mut().enumerate() {
            let bucket_ptr = *bucket.get_mut();

            let this_bucket_size = bucket_size;
            if i != 0 {
                bucket_size <<= 1;
            }

            if bucket_ptr.is_null() {
                continue;
            }

            unsafe { Box::from_raw(std::slice::from_raw_parts_mut(bucket_ptr, this_bucket_size)) };
        }
    }
}

impl<T: Send> ThreadLocal<T> {
    /// Creates a new empty `ThreadLocal`.
    pub fn new() -> ThreadLocal<T> {
        Self::with_capacity(2)
    }

    /// Creates a new `ThreadLocal` with an initial capacity. If fewer threads
    /// than `capacity` access the `ThreadLocal`, it will never reallocate.
    /// The capacity may be rounded up to the nearest power of two.
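    ///
    /// A minimal usage sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// // Pre-allocate room for 16 threads up front; the first 16 threads to
    /// // access the value will not trigger a bucket allocation.
    /// let tls: ThreadLocal<u32> = ThreadLocal::with_capacity(16);
    /// assert_eq!(tls.get(), None);
    /// ```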
    pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
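        // The number of buckets needed to hold `capacity` entries: the first
        // `b` buckets hold `2^(b-1)` entries in total, so `b` must be one more
        // than the bit length of `capacity - 1`. A capacity of zero needs no
        // pre-allocated buckets at all.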
        let allocated_buckets = capacity
            .checked_sub(1)
            .map(|c| usize::from(POINTER_WIDTH) - (c.leading_zeros() as usize) + 1)
            .unwrap_or(0);

        let mut buckets = [ptr::null_mut(); BUCKETS];
        let mut bucket_size = 1;
        for (i, bucket) in buckets[..allocated_buckets].iter_mut().enumerate() {
            *bucket = allocate_bucket::<T>(bucket_size);

            if i != 0 {
                bucket_size <<= 1;
            }
        }

        ThreadLocal {
            // Safety: AtomicPtr has the same representation as a pointer and arrays have the same
            // representation as a sequence of their inner type.
            buckets: unsafe { mem::transmute(buckets) },
            values: AtomicUsize::new(0),
            lock: Mutex::new(()),
        }
    }

    /// Returns the element for the current thread, if it exists.
    pub fn get(&self) -> Option<&T> {
        let thread = thread_id::get();
        self.get_inner(thread)
    }

    /// Returns the element for the current thread, or creates it if it doesn't
    /// exist.
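    ///
    /// A minimal sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<String> = ThreadLocal::new();
    /// assert_eq!(tls.get_or(|| String::from("hello")), "hello");
    /// // The closure is not called again once a value exists for this thread.
    /// assert_eq!(tls.get_or(|| unreachable!()), "hello");
    /// ```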
    pub fn get_or<F>(&self, create: F) -> &T
    where
        F: FnOnce() -> T,
    {
        unsafe {
            self.get_or_try(|| Ok::<T, ()>(create()))
                .unchecked_unwrap_ok()
        }
    }

    /// Returns the element for the current thread, or creates it if it doesn't
    /// exist. If `create` fails, that error is returned and no element is
    /// added.
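    ///
    /// A minimal sketch of the fallible path:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// // A failed creation leaves the ThreadLocal unchanged...
    /// assert!(tls.get_or_try(|| Err(())).is_err());
    /// assert_eq!(tls.get(), None);
    /// // ...while a successful one inserts the value.
    /// assert_eq!(tls.get_or_try(|| Ok::<u32, ()>(7)), Ok(&7));
    /// ```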
    pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
    where
        F: FnOnce() -> Result<T, E>,
    {
        let thread = thread_id::get();
        match self.get_inner(thread) {
            Some(x) => Ok(x),
            None => Ok(self.insert(thread, create()?)),
        }
    }

    fn get_inner(&self, thread: Thread) -> Option<&T> {
        let bucket_ptr =
            unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
        if bucket_ptr.is_null() {
            return None;
        }
        unsafe {
            let entry = &*bucket_ptr.add(thread.index);
            // Read without atomic operations as only this thread can set the value.
            if (&entry.present as *const _ as *const bool).read() {
                Some(&*(&*entry.value.get()).as_ptr())
            } else {
                None
            }
        }
    }

    #[cold]
    fn insert(&self, thread: Thread, data: T) -> &T {
        // Lock the Mutex to ensure only a single thread is allocating buckets at once
        let _guard = self.lock.lock().unwrap();

        let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };

        let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);
        let bucket_ptr = if bucket_ptr.is_null() {
            // Allocate a new bucket
            let bucket_ptr = allocate_bucket(thread.bucket_size);
            bucket_atomic_ptr.store(bucket_ptr, Ordering::Release);
            bucket_ptr
        } else {
            bucket_ptr
        };

        // The lock is only needed while buckets may be allocated; the entry
        // itself is only ever written by the thread that owns it.
        drop(_guard);

        // Insert the new element into the bucket
        let entry = unsafe { &*bucket_ptr.add(thread.index) };
        let value_ptr = entry.value.get();
        unsafe { value_ptr.write(MaybeUninit::new(data)) };
        entry.present.store(true, Ordering::Release);

        self.values.fetch_add(1, Ordering::Release);

        unsafe { &*(&*value_ptr).as_ptr() }
    }

    /// Returns an iterator over the local values of all threads in unspecified
    /// order.
    ///
    /// This call can be done safely, as `T` is required to implement [`Sync`].
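    ///
    /// A minimal sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// tls.get_or(|| 2);
    /// let sum: u32 = tls.iter().sum();
    /// assert_eq!(sum, 2);
    /// ```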
    pub fn iter(&self) -> Iter<'_, T>
    where
        T: Sync,
    {
        Iter {
            thread_local: self,
            raw: RawIter::new(),
        }
    }

    /// Returns a mutable iterator over the local values of all threads in
    /// unspecified order.
    ///
    /// Since this call borrows the `ThreadLocal` mutably, this operation can
    /// be done safely: the mutable borrow statically guarantees that no other
    /// threads are currently accessing their associated values.
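    ///
    /// A minimal sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let mut tls: ThreadLocal<u32> = ThreadLocal::new();
    /// tls.get_or(|| 1);
    /// for value in tls.iter_mut() {
    ///     *value += 1;
    /// }
    /// assert_eq!(tls.get(), Some(&2));
    /// ```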
    pub fn iter_mut(&mut self) -> IterMut<T> {
        IterMut {
            thread_local: self,
            raw: RawIter::new(),
        }
    }

    /// Removes all thread-specific values from the `ThreadLocal`, effectively
    /// resetting it to its original state.
    ///
    /// Since this call borrows the `ThreadLocal` mutably, this operation can
    /// be done safely: the mutable borrow statically guarantees that no other
    /// threads are currently accessing their associated values.
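    ///
    /// A minimal sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let mut tls: ThreadLocal<u32> = ThreadLocal::new();
    /// tls.get_or(|| 1);
    /// tls.clear();
    /// assert_eq!(tls.get(), None);
    /// ```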
    pub fn clear(&mut self) {
        *self = ThreadLocal::new();
    }
}

impl<T: Send> IntoIterator for ThreadLocal<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    fn into_iter(self) -> IntoIter<T> {
        IntoIter {
            thread_local: self,
            raw: RawIter::new(),
        }
    }
}

impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal<T> {
    type Item = &'a T;
    type IntoIter = Iter<'a, T>;

    fn into_iter(self) -> Self::IntoIter {
        self.iter()
    }
}

impl<'a, T: Send> IntoIterator for &'a mut ThreadLocal<T> {
    type Item = &'a mut T;
    type IntoIter = IterMut<'a, T>;

    fn into_iter(self) -> IterMut<'a, T> {
        self.iter_mut()
    }
}

impl<T: Send + Default> ThreadLocal<T> {
    /// Returns the element for the current thread, or creates a default one if
    /// it doesn't exist.
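    ///
    /// A minimal sketch:
    ///
    /// ```rust
    /// use thread_local::ThreadLocal;
    /// let tls: ThreadLocal<u32> = ThreadLocal::new();
    /// assert_eq!(tls.get_or_default(), &0);
    /// ```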
    pub fn get_or_default(&self) -> &T {
        self.get_or(Default::default)
    }
}

impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "ThreadLocal {{ local_data: {:?} }}", self.get())
    }
}

impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}

#[derive(Debug)]
struct RawIter {
    yielded: usize,
    bucket: usize,
    bucket_size: usize,
    index: usize,
}
impl RawIter {
    #[inline]
    fn new() -> Self {
        Self {
            yielded: 0,
            bucket: 0,
            bucket_size: 1,
            index: 0,
        }
    }

    fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
        while self.bucket < BUCKETS {
            let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
            let bucket = bucket.load(Ordering::Acquire);

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    let entry = unsafe { &*bucket.add(self.index) };
                    self.index += 1;
                    if entry.present.load(Ordering::Acquire) {
                        self.yielded += 1;
                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
                    }
                }
            }

            self.next_bucket();
        }
        None
    }
    fn next_mut<'a, T: Send>(
        &mut self,
        thread_local: &'a mut ThreadLocal<T>,
    ) -> Option<&'a mut Entry<T>> {
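        // If every recorded value has already been yielded there is nothing
        // left to scan; the non-atomic `get_mut` is safe because we hold a
        // mutable borrow of the whole ThreadLocal.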
        if *thread_local.values.get_mut() == self.yielded {
            return None;
        }

        loop {
            let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
            let bucket = *bucket.get_mut();

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    let entry = unsafe { &mut *bucket.add(self.index) };
                    self.index += 1;
                    if *entry.present.get_mut() {
                        self.yielded += 1;
                        return Some(entry);
                    }
                }
            }

            self.next_bucket();
        }
    }

    #[inline]
    fn next_bucket(&mut self) {
        if self.bucket != 0 {
            self.bucket_size <<= 1;
        }
        self.bucket += 1;
        self.index = 0;
    }

    fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        let total = thread_local.values.load(Ordering::Acquire);
        (total - self.yielded, None)
    }
    fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
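        // Safety: callers (IterMut/IntoIter) hold a mutable borrow of, or own,
        // the ThreadLocal, so no other thread can insert concurrently and this
        // non-atomic read of the counter cannot race.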
        let total = unsafe { *(&thread_local.values as *const AtomicUsize as *const usize) };
        let remaining = total - self.yielded;
        (remaining, Some(remaining))
    }
}

/// Iterator over the contents of a `ThreadLocal`.
#[derive(Debug)]
pub struct Iter<'a, T: Send + Sync> {
    thread_local: &'a ThreadLocal<T>,
    raw: RawIter,
}

impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
    type Item = &'a T;
    fn next(&mut self) -> Option<Self::Item> {
        self.raw.next(self.thread_local)
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint(self.thread_local)
    }
}
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}

/// Mutable iterator over the contents of a `ThreadLocal`.
pub struct IterMut<'a, T: Send> {
    thread_local: &'a mut ThreadLocal<T>,
    raw: RawIter,
}

impl<'a, T: Send> Iterator for IterMut<'a, T> {
    type Item = &'a mut T;
    fn next(&mut self) -> Option<&'a mut T> {
        self.raw
            .next_mut(self.thread_local)
            .map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(self.thread_local)
    }
}

impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
impl<T: Send> FusedIterator for IterMut<'_, T> {}

// Manual impl so we don't call Debug on the ThreadLocal, as doing so would create a reference to
// this thread's value that potentially aliases with a mutable reference we have given out.
impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("IterMut").field("raw", &self.raw).finish()
    }
}

/// An iterator that moves out of a `ThreadLocal`.
#[derive(Debug)]
pub struct IntoIter<T: Send> {
    thread_local: ThreadLocal<T>,
    raw: RawIter,
}

impl<T: Send> Iterator for IntoIter<T> {
    type Item = T;
    fn next(&mut self) -> Option<T> {
        self.raw.next_mut(&mut self.thread_local).map(|entry| {
            *entry.present.get_mut() = false;
            unsafe {
                std::mem::replace(&mut *entry.value.get(), MaybeUninit::uninit()).assume_init()
            }
        })
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(&self.thread_local)
    }
}

impl<T: Send> ExactSizeIterator for IntoIter<T> {}
impl<T: Send> FusedIterator for IntoIter<T> {}

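/// Allocates a bucket of `size` uninitialized entries and leaks it as a raw
/// pointer; ownership is reclaimed in `ThreadLocal`'s `Drop` impl via
/// `Box::from_raw`.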
fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
    Box::into_raw(
        (0..size)
            .map(|_| Entry::<T> {
                present: AtomicBool::new(false),
                value: UnsafeCell::new(MaybeUninit::uninit()),
            })
            .collect(),
    ) as *mut _
}

#[cfg(test)]
mod tests {
    use super::ThreadLocal;
    use std::cell::RefCell;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::sync::Arc;
    use std::thread;

    fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
        let count = AtomicUsize::new(0);
        Arc::new(move || count.fetch_add(1, Relaxed))
    }

    #[test]
    fn same_thread() {
        let create = make_create();
        let mut tls = ThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));
        tls.clear();
        assert_eq!(None, tls.get());
    }

    #[test]
    fn different_thread() {
        let create = make_create();
        let tls = Arc::new(ThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());

        let tls2 = tls.clone();
        let create2 = create.clone();
        thread::spawn(move || {
            assert_eq!(None, tls2.get());
            assert_eq!(1, *tls2.get_or(|| create2()));
            assert_eq!(Some(&1), tls2.get());
        })
        .join()
        .unwrap();

        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    #[test]
    fn iter() {
        let tls = Arc::new(ThreadLocal::new());
        tls.get_or(|| Box::new(1));

        let tls2 = tls.clone();
        thread::spawn(move || {
            tls2.get_or(|| Box::new(2));
            let tls3 = tls2.clone();
            thread::spawn(move || {
                tls3.get_or(|| Box::new(3));
            })
            .join()
            .unwrap();
            drop(tls2);
        })
        .join()
        .unwrap();

        let mut tls = Arc::try_unwrap(tls).unwrap();

        let mut v = tls.iter().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.iter_mut().map(|x| **x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);

        let mut v = tls.into_iter().map(|x| *x).collect::<Vec<i32>>();
        v.sort_unstable();
        assert_eq!(vec![1, 2, 3], v);
    }

    #[test]
    fn test_drop() {
        let local = ThreadLocal::new();
        struct Dropped(Arc<AtomicUsize>);
        impl Drop for Dropped {
            fn drop(&mut self) {
                self.0.fetch_add(1, Relaxed);
            }
        }

        let dropped = Arc::new(AtomicUsize::new(0));
        local.get_or(|| Dropped(dropped.clone()));
        assert_eq!(dropped.load(Relaxed), 0);
        drop(local);
        assert_eq!(dropped.load(Relaxed), 1);
    }

    #[test]
    fn is_sync() {
        fn foo<T: Sync>() {}
        foo::<ThreadLocal<String>>();
        foo::<ThreadLocal<RefCell<String>>>();
    }
}