• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2016 Amanieu d'Antras
2 //
3 // Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4 // http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5 // http://opensource.org/licenses/MIT>, at your option. This file may not be
6 // copied, modified, or distributed except according to those terms.
7 
8 use crate::raw_mutex::RawMutex;
9 use core::num::NonZeroUsize;
10 use lock_api::{self, GetThreadId};
11 
12 /// Implementation of the `GetThreadId` trait for `lock_api::ReentrantMutex`.
13 pub struct RawThreadId;
14 
15 unsafe impl GetThreadId for RawThreadId {
16     const INIT: RawThreadId = RawThreadId;
17 
nonzero_thread_id(&self) -> NonZeroUsize18     fn nonzero_thread_id(&self) -> NonZeroUsize {
19         // The address of a thread-local variable is guaranteed to be unique to the
20         // current thread, and is also guaranteed to be non-zero. The variable has to have a
21         // non-zero size to guarantee it has a unique address for each thread.
22         thread_local!(static KEY: u8 = 0);
23         KEY.with(|x| {
24             NonZeroUsize::new(x as *const _ as usize)
25                 .expect("thread-local variable address is null")
26         })
27     }
28 }
29 
30 /// A mutex which can be recursively locked by a single thread.
31 ///
32 /// This type is identical to `Mutex` except for the following points:
33 ///
34 /// - Locking multiple times from the same thread will work correctly instead of
35 ///   deadlocking.
36 /// - `ReentrantMutexGuard` does not give mutable references to the locked data.
37 ///   Use a `RefCell` if you need this.
38 ///
39 /// See [`Mutex`](crate::Mutex) for more details about the underlying mutex
40 /// primitive.
41 pub type ReentrantMutex<T> = lock_api::ReentrantMutex<RawMutex, RawThreadId, T>;
42 
43 /// Creates a new reentrant mutex in an unlocked state ready for use.
44 ///
45 /// This allows creating a reentrant mutex in a constant context on stable Rust.
const_reentrant_mutex<T>(val: T) -> ReentrantMutex<T>46 pub const fn const_reentrant_mutex<T>(val: T) -> ReentrantMutex<T> {
47     ReentrantMutex::const_new(
48         <RawMutex as lock_api::RawMutex>::INIT,
49         <RawThreadId as lock_api::GetThreadId>::INIT,
50         val,
51     )
52 }
53 
54 /// An RAII implementation of a "scoped lock" of a reentrant mutex. When this structure
55 /// is dropped (falls out of scope), the lock will be unlocked.
56 ///
57 /// The data protected by the mutex can be accessed through this guard via its
58 /// `Deref` implementation.
59 pub type ReentrantMutexGuard<'a, T> = lock_api::ReentrantMutexGuard<'a, RawMutex, RawThreadId, T>;
60 
61 /// An RAII mutex guard returned by `ReentrantMutexGuard::map`, which can point to a
62 /// subfield of the protected data.
63 ///
64 /// The main difference between `MappedReentrantMutexGuard` and `ReentrantMutexGuard` is that the
65 /// former doesn't support temporarily unlocking and re-locking, since that
66 /// could introduce soundness issues if the locked object is modified by another
67 /// thread.
68 pub type MappedReentrantMutexGuard<'a, T> =
69     lock_api::MappedReentrantMutexGuard<'a, RawMutex, RawThreadId, T>;
70 
#[cfg(test)]
mod tests {
    use crate::ReentrantMutex;
    use crate::ReentrantMutexGuard;
    use std::cell::RefCell;
    use std::sync::mpsc::channel;
    use std::sync::Arc;
    use std::thread;

    #[cfg(feature = "serde")]
    use bincode::{deserialize, serialize};

    /// Locking three levels deep from one thread works and every guard reads the value.
    #[test]
    fn smoke() {
        let m = ReentrantMutex::new(2);
        {
            let a = m.lock();
            {
                let b = m.lock();
                {
                    let c = m.lock();
                    assert_eq!(*c, 2);
                }
                assert_eq!(*b, 2);
            }
            assert_eq!(*a, 2);
        }
    }

    /// The lock still excludes other threads: the child only observes the final sum.
    #[test]
    fn is_mutex() {
        let m = Arc::new(ReentrantMutex::new(RefCell::new(0)));
        let m2 = m.clone();
        // Held for the whole accumulation loop so the child cannot get in early.
        let lock = m.lock();
        let child = thread::spawn(move || {
            let lock = m2.lock();
            // 0 + 1 + ... + 99
            assert_eq!(*lock.borrow(), 4950);
        });
        for i in 0..100 {
            // Re-entrant acquisition while the outer `lock` is still alive.
            let lock = m.lock();
            *lock.borrow_mut() += i;
        }
        drop(lock);
        child.join().unwrap();
    }

    /// `try_lock` succeeds (repeatedly) on the owning thread and fails on another one.
    #[test]
    fn trylock_works() {
        let m = Arc::new(ReentrantMutex::new(()));
        let m2 = m.clone();
        // Check that the owning thread really gets the lock, including re-entrantly;
        // merely binding the results would let a `None` go unnoticed.
        let first = m.try_lock();
        assert!(first.is_some());
        let second = m.try_lock();
        assert!(second.is_some());
        // The guards above are still alive here, so a different thread must be refused.
        thread::spawn(move || {
            let lock = m2.try_lock();
            assert!(lock.is_none());
        })
        .join()
        .unwrap();
        let third = m.try_lock();
        assert!(third.is_some());
    }

    /// The `Debug` output of the mutex shows the protected data.
    #[test]
    fn test_reentrant_mutex_debug() {
        let mutex = ReentrantMutex::new(vec![0u8, 10]);

        assert_eq!(format!("{:?}", mutex), "ReentrantMutex { data: [0, 10] }");
    }

    /// `bump` lets a waiting thread acquire the lock while this thread keeps its guard.
    #[test]
    fn test_reentrant_mutex_bump() {
        let mutex = Arc::new(ReentrantMutex::new(()));
        let mutex2 = mutex.clone();

        let mut guard = mutex.lock();

        let (tx, rx) = channel();

        thread::spawn(move || {
            let _guard = mutex2.lock();
            tx.send(()).unwrap();
        });

        // `bump()` repeatedly until the thread starts up and requests the lock
        while rx.try_recv().is_err() {
            ReentrantMutexGuard::bump(&mut guard);
        }
    }

    /// A serialize/deserialize round trip preserves the protected contents.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        let contents: Vec<u8> = vec![0, 1, 2];
        let mutex = ReentrantMutex::new(contents.clone());

        let serialized = serialize(&mutex).unwrap();
        let deserialized: ReentrantMutex<Vec<u8>> = deserialize(&serialized).unwrap();

        assert_eq!(*(mutex.lock()), *(deserialized.lock()));
        assert_eq!(contents, *(deserialized.lock()));
    }
}
172