1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef COUNT_NEW_H
10 #define COUNT_NEW_H
11
12 # include <cstdlib>
13 # include <cassert>
14 # include <new>
15
16 #include "test_macros.h"
17
// When the tests are built with sanitizers, allocation counting is switched
// off: DISABLE_NEW_COUNT makes every MemCounter check pass trivially and keeps
// the counting operator new/delete replacements below from being defined.
// NOTE(review): presumably because sanitizer runtimes provide their own
// operator new/delete — confirm.
#ifdef TEST_HAS_SANITIZERS
#  define DISABLE_NEW_COUNT
#endif
21
22 namespace detail
23 {
24 TEST_NORETURN
throw_bad_alloc_helper()25 inline void throw_bad_alloc_helper() {
26 #ifndef TEST_HAS_NO_EXCEPTIONS
27 throw std::bad_alloc();
28 #else
29 std::abort();
30 #endif
31 }
32 }
33
34 class MemCounter
35 {
36 public:
37 // Make MemCounter super hard to accidentally construct or copy.
38 class MemCounterCtorArg_ {};
MemCounter(MemCounterCtorArg_)39 explicit MemCounter(MemCounterCtorArg_) { reset(); }
40
41 private:
42 MemCounter(MemCounter const &);
43 MemCounter & operator=(MemCounter const &);
44
45 public:
46 // All checks return true when disable_checking is enabled.
47 static const bool disable_checking;
48
49 // Disallow any allocations from occurring. Useful for testing that
50 // code doesn't perform any allocations.
51 bool disable_allocations;
52
53 // number of allocations to throw after. Default (unsigned)-1. If
54 // throw_after has the default value it will never be decremented.
55 static const unsigned never_throw_value = static_cast<unsigned>(-1);
56 unsigned throw_after;
57
58 int outstanding_new;
59 int new_called;
60 int delete_called;
61 int aligned_new_called;
62 int aligned_delete_called;
63 std::size_t last_new_size;
64 std::size_t last_new_align;
65 std::size_t last_delete_align;
66
67 int outstanding_array_new;
68 int new_array_called;
69 int delete_array_called;
70 int aligned_new_array_called;
71 int aligned_delete_array_called;
72 std::size_t last_new_array_size;
73 std::size_t last_new_array_align;
74 std::size_t last_delete_array_align;
75
76 public:
newCalled(std::size_t s)77 void newCalled(std::size_t s)
78 {
79 assert(disable_allocations == false);
80 assert(s);
81 if (throw_after == 0) {
82 throw_after = never_throw_value;
83 detail::throw_bad_alloc_helper();
84 } else if (throw_after != never_throw_value) {
85 --throw_after;
86 }
87 ++new_called;
88 ++outstanding_new;
89 last_new_size = s;
90 }
91
alignedNewCalled(std::size_t s,std::size_t a)92 void alignedNewCalled(std::size_t s, std::size_t a) {
93 newCalled(s);
94 ++aligned_new_called;
95 last_new_align = a;
96 }
97
deleteCalled(void * p)98 void deleteCalled(void * p)
99 {
100 assert(p);
101 --outstanding_new;
102 ++delete_called;
103 }
104
alignedDeleteCalled(void * p,std::size_t a)105 void alignedDeleteCalled(void *p, std::size_t a) {
106 deleteCalled(p);
107 ++aligned_delete_called;
108 last_delete_align = a;
109 }
110
newArrayCalled(std::size_t s)111 void newArrayCalled(std::size_t s)
112 {
113 assert(disable_allocations == false);
114 assert(s);
115 if (throw_after == 0) {
116 throw_after = never_throw_value;
117 detail::throw_bad_alloc_helper();
118 } else {
119 // don't decrement throw_after here. newCalled will end up doing that.
120 }
121 ++outstanding_array_new;
122 ++new_array_called;
123 last_new_array_size = s;
124 }
125
alignedNewArrayCalled(std::size_t s,std::size_t a)126 void alignedNewArrayCalled(std::size_t s, std::size_t a) {
127 newArrayCalled(s);
128 ++aligned_new_array_called;
129 last_new_array_align = a;
130 }
131
deleteArrayCalled(void * p)132 void deleteArrayCalled(void * p)
133 {
134 assert(p);
135 --outstanding_array_new;
136 ++delete_array_called;
137 }
138
alignedDeleteArrayCalled(void * p,std::size_t a)139 void alignedDeleteArrayCalled(void * p, std::size_t a) {
140 deleteArrayCalled(p);
141 ++aligned_delete_array_called;
142 last_delete_array_align = a;
143 }
144
disableAllocations()145 void disableAllocations()
146 {
147 disable_allocations = true;
148 }
149
enableAllocations()150 void enableAllocations()
151 {
152 disable_allocations = false;
153 }
154
reset()155 void reset()
156 {
157 disable_allocations = false;
158 throw_after = never_throw_value;
159
160 outstanding_new = 0;
161 new_called = 0;
162 delete_called = 0;
163 aligned_new_called = 0;
164 aligned_delete_called = 0;
165 last_new_size = 0;
166 last_new_align = 0;
167
168 outstanding_array_new = 0;
169 new_array_called = 0;
170 delete_array_called = 0;
171 aligned_new_array_called = 0;
172 aligned_delete_array_called = 0;
173 last_new_array_size = 0;
174 last_new_array_align = 0;
175 }
176
177 public:
checkOutstandingNewEq(int n)178 bool checkOutstandingNewEq(int n) const
179 {
180 return disable_checking || n == outstanding_new;
181 }
182
checkOutstandingNewNotEq(int n)183 bool checkOutstandingNewNotEq(int n) const
184 {
185 return disable_checking || n != outstanding_new;
186 }
187
checkNewCalledEq(int n)188 bool checkNewCalledEq(int n) const
189 {
190 return disable_checking || n == new_called;
191 }
192
checkNewCalledNotEq(int n)193 bool checkNewCalledNotEq(int n) const
194 {
195 return disable_checking || n != new_called;
196 }
197
checkNewCalledGreaterThan(int n)198 bool checkNewCalledGreaterThan(int n) const
199 {
200 return disable_checking || new_called > n;
201 }
202
checkDeleteCalledEq(int n)203 bool checkDeleteCalledEq(int n) const
204 {
205 return disable_checking || n == delete_called;
206 }
207
checkDeleteCalledNotEq(int n)208 bool checkDeleteCalledNotEq(int n) const
209 {
210 return disable_checking || n != delete_called;
211 }
212
checkAlignedNewCalledEq(int n)213 bool checkAlignedNewCalledEq(int n) const
214 {
215 return disable_checking || n == aligned_new_called;
216 }
217
checkAlignedNewCalledNotEq(int n)218 bool checkAlignedNewCalledNotEq(int n) const
219 {
220 return disable_checking || n != aligned_new_called;
221 }
222
checkAlignedNewCalledGreaterThan(int n)223 bool checkAlignedNewCalledGreaterThan(int n) const
224 {
225 return disable_checking || aligned_new_called > n;
226 }
227
checkAlignedDeleteCalledEq(int n)228 bool checkAlignedDeleteCalledEq(int n) const
229 {
230 return disable_checking || n == aligned_delete_called;
231 }
232
checkAlignedDeleteCalledNotEq(int n)233 bool checkAlignedDeleteCalledNotEq(int n) const
234 {
235 return disable_checking || n != aligned_delete_called;
236 }
237
checkLastNewSizeEq(std::size_t n)238 bool checkLastNewSizeEq(std::size_t n) const
239 {
240 return disable_checking || n == last_new_size;
241 }
242
checkLastNewSizeNotEq(std::size_t n)243 bool checkLastNewSizeNotEq(std::size_t n) const
244 {
245 return disable_checking || n != last_new_size;
246 }
247
checkLastNewAlignEq(std::size_t n)248 bool checkLastNewAlignEq(std::size_t n) const
249 {
250 return disable_checking || n == last_new_align;
251 }
252
checkLastNewAlignNotEq(std::size_t n)253 bool checkLastNewAlignNotEq(std::size_t n) const
254 {
255 return disable_checking || n != last_new_align;
256 }
257
checkLastDeleteAlignEq(std::size_t n)258 bool checkLastDeleteAlignEq(std::size_t n) const
259 {
260 return disable_checking || n == last_delete_align;
261 }
262
checkLastDeleteAlignNotEq(std::size_t n)263 bool checkLastDeleteAlignNotEq(std::size_t n) const
264 {
265 return disable_checking || n != last_delete_align;
266 }
267
checkOutstandingArrayNewEq(int n)268 bool checkOutstandingArrayNewEq(int n) const
269 {
270 return disable_checking || n == outstanding_array_new;
271 }
272
checkOutstandingArrayNewNotEq(int n)273 bool checkOutstandingArrayNewNotEq(int n) const
274 {
275 return disable_checking || n != outstanding_array_new;
276 }
277
checkNewArrayCalledEq(int n)278 bool checkNewArrayCalledEq(int n) const
279 {
280 return disable_checking || n == new_array_called;
281 }
282
checkNewArrayCalledNotEq(int n)283 bool checkNewArrayCalledNotEq(int n) const
284 {
285 return disable_checking || n != new_array_called;
286 }
287
checkDeleteArrayCalledEq(int n)288 bool checkDeleteArrayCalledEq(int n) const
289 {
290 return disable_checking || n == delete_array_called;
291 }
292
checkDeleteArrayCalledNotEq(int n)293 bool checkDeleteArrayCalledNotEq(int n) const
294 {
295 return disable_checking || n != delete_array_called;
296 }
297
checkAlignedNewArrayCalledEq(int n)298 bool checkAlignedNewArrayCalledEq(int n) const
299 {
300 return disable_checking || n == aligned_new_array_called;
301 }
302
checkAlignedNewArrayCalledNotEq(int n)303 bool checkAlignedNewArrayCalledNotEq(int n) const
304 {
305 return disable_checking || n != aligned_new_array_called;
306 }
307
checkAlignedNewArrayCalledGreaterThan(int n)308 bool checkAlignedNewArrayCalledGreaterThan(int n) const
309 {
310 return disable_checking || aligned_new_array_called > n;
311 }
312
checkAlignedDeleteArrayCalledEq(int n)313 bool checkAlignedDeleteArrayCalledEq(int n) const
314 {
315 return disable_checking || n == aligned_delete_array_called;
316 }
317
checkAlignedDeleteArrayCalledNotEq(int n)318 bool checkAlignedDeleteArrayCalledNotEq(int n) const
319 {
320 return disable_checking || n != aligned_delete_array_called;
321 }
322
checkLastNewArraySizeEq(std::size_t n)323 bool checkLastNewArraySizeEq(std::size_t n) const
324 {
325 return disable_checking || n == last_new_array_size;
326 }
327
checkLastNewArraySizeNotEq(std::size_t n)328 bool checkLastNewArraySizeNotEq(std::size_t n) const
329 {
330 return disable_checking || n != last_new_array_size;
331 }
332
checkLastNewArrayAlignEq(std::size_t n)333 bool checkLastNewArrayAlignEq(std::size_t n) const
334 {
335 return disable_checking || n == last_new_array_align;
336 }
337
checkLastNewArrayAlignNotEq(std::size_t n)338 bool checkLastNewArrayAlignNotEq(std::size_t n) const
339 {
340 return disable_checking || n != last_new_array_align;
341 }
342 };
343
344 #ifdef DISABLE_NEW_COUNT
345 const bool MemCounter::disable_checking = true;
346 #else
347 const bool MemCounter::disable_checking = false;
348 #endif
349
350 #ifdef _MSC_VER
351 #pragma warning(push)
352 #pragma warning(disable: 4640) // '%s' construction of local static object is not thread safe (/Zc:threadSafeInit-)
353 #endif // _MSC_VER
getGlobalMemCounter()354 inline MemCounter* getGlobalMemCounter() {
355 static MemCounter counter((MemCounter::MemCounterCtorArg_()));
356 return &counter;
357 }
358 #ifdef _MSC_VER
359 #pragma warning(pop)
360 #endif
361
362 MemCounter &globalMemCounter = *getGlobalMemCounter();
363
364 #ifndef DISABLE_NEW_COUNT
new(std::size_t s)365 void* operator new(std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
366 {
367 getGlobalMemCounter()->newCalled(s);
368 void* ret = std::malloc(s);
369 if (ret == nullptr)
370 detail::throw_bad_alloc_helper();
371 return ret;
372 }
373
delete(void * p)374 void operator delete(void* p) TEST_NOEXCEPT
375 {
376 getGlobalMemCounter()->deleteCalled(p);
377 std::free(p);
378 }
379
TEST_THROW_SPEC(std::bad_alloc)380 void* operator new[](std::size_t s) TEST_THROW_SPEC(std::bad_alloc)
381 {
382 getGlobalMemCounter()->newArrayCalled(s);
383 return operator new(s);
384 }
385
386 void operator delete[](void* p) TEST_NOEXCEPT
387 {
388 getGlobalMemCounter()->deleteArrayCalled(p);
389 operator delete(p);
390 }
391
392 #ifndef TEST_HAS_NO_ALIGNED_ALLOCATION
// Selects the MSVC CRT aligned allocator pair (_aligned_malloc/_aligned_free)
// for over-aligned allocations; other platforms use posix_memalign/free.
#if defined(_LIBCPP_MSVCRT_LIKE) || (defined(_WIN32) && !defined(_LIBCPP_VERSION))
#  define USE_ALIGNED_ALLOC
#endif
397
new(std::size_t s,std::align_val_t av)398 void* operator new(std::size_t s, std::align_val_t av) TEST_THROW_SPEC(std::bad_alloc) {
399 const std::size_t a = static_cast<std::size_t>(av);
400 getGlobalMemCounter()->alignedNewCalled(s, a);
401 void *ret;
402 #ifdef USE_ALIGNED_ALLOC
403 ret = _aligned_malloc(s, a);
404 #else
405 posix_memalign(&ret, a, s);
406 #endif
407 if (ret == nullptr)
408 detail::throw_bad_alloc_helper();
409 return ret;
410 }
411
delete(void * p,std::align_val_t av)412 void operator delete(void *p, std::align_val_t av) TEST_NOEXCEPT {
413 const std::size_t a = static_cast<std::size_t>(av);
414 getGlobalMemCounter()->alignedDeleteCalled(p, a);
415 if (p) {
416 #ifdef USE_ALIGNED_ALLOC
417 ::_aligned_free(p);
418 #else
419 ::free(p);
420 #endif
421 }
422 }
423
TEST_THROW_SPEC(std::bad_alloc)424 void* operator new[](std::size_t s, std::align_val_t av) TEST_THROW_SPEC(std::bad_alloc) {
425 const std::size_t a = static_cast<std::size_t>(av);
426 getGlobalMemCounter()->alignedNewArrayCalled(s, a);
427 return operator new(s, av);
428 }
429
430 void operator delete[](void *p, std::align_val_t av) TEST_NOEXCEPT {
431 const std::size_t a = static_cast<std::size_t>(av);
432 getGlobalMemCounter()->alignedDeleteArrayCalled(p, a);
433 return operator delete(p, av);
434 }
435
436 #endif // TEST_HAS_NO_ALIGNED_ALLOCATION
437
438 #endif // DISABLE_NEW_COUNT
439
440 struct DisableAllocationGuard {
m_disabledDisableAllocationGuard441 explicit DisableAllocationGuard(bool disable = true) : m_disabled(disable)
442 {
443 // Don't re-disable if already disabled.
444 if (globalMemCounter.disable_allocations == true) m_disabled = false;
445 if (m_disabled) globalMemCounter.disableAllocations();
446 }
447
releaseDisableAllocationGuard448 void release() {
449 if (m_disabled) globalMemCounter.enableAllocations();
450 m_disabled = false;
451 }
452
~DisableAllocationGuardDisableAllocationGuard453 ~DisableAllocationGuard() {
454 release();
455 }
456
457 private:
458 bool m_disabled;
459
460 DisableAllocationGuard(DisableAllocationGuard const&);
461 DisableAllocationGuard& operator=(DisableAllocationGuard const&);
462 };
463
464 struct RequireAllocationGuard {
465 explicit RequireAllocationGuard(std::size_t RequireAtLeast = 1)
m_req_allocRequireAllocationGuard466 : m_req_alloc(RequireAtLeast),
467 m_new_count_on_init(globalMemCounter.new_called),
468 m_outstanding_new_on_init(globalMemCounter.outstanding_new),
469 m_exactly(false)
470 {
471 }
472
requireAtLeastRequireAllocationGuard473 void requireAtLeast(std::size_t N) { m_req_alloc = N; m_exactly = false; }
requireExactlyRequireAllocationGuard474 void requireExactly(std::size_t N) { m_req_alloc = N; m_exactly = true; }
475
~RequireAllocationGuardRequireAllocationGuard476 ~RequireAllocationGuard() {
477 assert(globalMemCounter.checkOutstandingNewEq(static_cast<int>(m_outstanding_new_on_init)));
478 std::size_t Expect = m_new_count_on_init + m_req_alloc;
479 assert(globalMemCounter.checkNewCalledEq(static_cast<int>(Expect)) ||
480 (!m_exactly && globalMemCounter.checkNewCalledGreaterThan(static_cast<int>(Expect))));
481 }
482
483 private:
484 std::size_t m_req_alloc;
485 const std::size_t m_new_count_on_init;
486 const std::size_t m_outstanding_new_on_init;
487 bool m_exactly;
488 RequireAllocationGuard(RequireAllocationGuard const&);
489 RequireAllocationGuard& operator=(RequireAllocationGuard const&);
490 };
491
492 #endif /* COUNT_NEW_H */
493