• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 use std::alloc::{GlobalAlloc, Layout, System};
2 use std::ptr::null_mut;
3 use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
4 
5 use bytes::{Buf, Bytes};
6 
7 #[global_allocator]
8 static LEDGER: Ledger = Ledger::new();
9 
10 const LEDGER_LENGTH: usize = 2048;
11 
12 struct Ledger {
13     alloc_table: [(AtomicPtr<u8>, AtomicUsize); LEDGER_LENGTH],
14 }
15 
16 impl Ledger {
new() -> Self17     const fn new() -> Self {
18         const ELEM: (AtomicPtr<u8>, AtomicUsize) =
19             (AtomicPtr::new(null_mut()), AtomicUsize::new(0));
20         let alloc_table = [ELEM; LEDGER_LENGTH];
21 
22         Self { alloc_table }
23     }
24 
25     /// Iterate over our table until we find an open entry, then insert into said entry
insert(&self, ptr: *mut u8, size: usize)26     fn insert(&self, ptr: *mut u8, size: usize) {
27         for (entry_ptr, entry_size) in self.alloc_table.iter() {
28             // SeqCst is good enough here, we don't care about perf, i just want to be correct!
29             if entry_ptr
30                 .compare_exchange(null_mut(), ptr, Ordering::SeqCst, Ordering::SeqCst)
31                 .is_ok()
32             {
33                 entry_size.store(size, Ordering::SeqCst);
34                 break;
35             }
36         }
37     }
38 
remove(&self, ptr: *mut u8) -> usize39     fn remove(&self, ptr: *mut u8) -> usize {
40         for (entry_ptr, entry_size) in self.alloc_table.iter() {
41             // set the value to be something that will never try and be deallocated, so that we
42             // don't have any chance of a race condition
43             //
44             // dont worry, LEDGER_LENGTH is really long to compensate for us not reclaiming space
45             if entry_ptr
46                 .compare_exchange(
47                     ptr,
48                     invalid_ptr(usize::MAX),
49                     Ordering::SeqCst,
50                     Ordering::SeqCst,
51                 )
52                 .is_ok()
53             {
54                 return entry_size.load(Ordering::SeqCst);
55             }
56         }
57 
58         panic!("Couldn't find a matching entry for {:x?}", ptr);
59     }
60 }
61 
62 unsafe impl GlobalAlloc for Ledger {
alloc(&self, layout: Layout) -> *mut u863     unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
64         let size = layout.size();
65         let ptr = System.alloc(layout);
66         self.insert(ptr, size);
67         ptr
68     }
69 
dealloc(&self, ptr: *mut u8, layout: Layout)70     unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
71         let orig_size = self.remove(ptr);
72 
73         if orig_size != layout.size() {
74             panic!(
75                 "bad dealloc: alloc size was {}, dealloc size is {}",
76                 orig_size,
77                 layout.size()
78             );
79         } else {
80             System.dealloc(ptr, layout);
81         }
82     }
83 }
84 
#[test]
fn test_bytes_advance() {
    // Move the start of the view forward by one byte, then release it. The
    // ledger allocator panics if the underlying buffer is freed with a size
    // different from the one it was allocated with.
    {
        let mut shifted = Bytes::from(vec![10u8, 20, 30]);
        shifted.advance(1);
    } // `shifted` is dropped here.
}
91 
#[test]
fn test_bytes_truncate() {
    // Shorten the view from the end, then free it. The ledger allocator
    // checks that the free uses the same size the buffer was allocated with.
    let source: Vec<u8> = vec![10, 20, 30];
    let mut bytes = Bytes::from(source);
    bytes.truncate(2);
    drop(bytes);
}
98 
#[test]
fn test_bytes_truncate_and_advance() {
    // Trim one byte from the back and one from the front, so the remaining
    // view ([20]) touches neither edge of the original buffer, then free it
    // under the ledger allocator's size check.
    let mut view = Bytes::from(vec![10u8, 20, 30]);
    view.truncate(2);
    view.advance(1);
    drop(view);
}
106 
/// Builds a dangling pointer whose address is exactly `addr`.
///
/// This lets a plain integer (such as a sentinel) be stored in a
/// pointer-typed field. The result must never be dereferenced.
#[inline]
fn invalid_ptr<T>(addr: usize) -> *mut T {
    let base: *mut u8 = null_mut();
    let dangling = base.wrapping_add(addr);
    debug_assert_eq!(dangling as usize, addr);
    dangling as *mut T
}
115 
#[test]
fn test_bytes_into_vec() {
    let expected = vec![33u8; 1024];

    // kind == KIND_VEC: convert straight back.
    {
        let fresh = Bytes::from(expected.clone());
        assert_eq!(Vec::from(fresh), expected);
    }

    // kind == KIND_ARC, ref_cnt == 1: a clone is made and dropped before converting.
    {
        let fresh = Bytes::from(expected.clone());
        drop(fresh.clone());
        assert_eq!(Vec::from(fresh), expected);
    }

    // First conversion: kind == KIND_ARC, ref_cnt == 2.
    // Second conversion: vtable == SHARED_VTABLE, kind == KIND_ARC, ref_cnt == 1.
    {
        let first = Bytes::from(expected.clone());
        let second = first.clone();
        assert_eq!(Vec::from(first), expected);
        assert_eq!(Vec::from(second), expected);
    }

    // offset != 0: both halves of a split must convert correctly.
    {
        let mut head = Bytes::from(expected.clone());
        let tail = head.split_off(20);
        assert_eq!(Vec::from(tail), expected[20..]);
        assert_eq!(Vec::from(head), expected[..20]);
    }
}
144