• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
#include <c10/cuda/CUDACachingAllocator.h>
#include <gtest/gtest.h>

#include <cstddef>
3 
// Number of SEGMENT_ALLOC trace entries observed by SegmentAllocTraceTracker.
static int segmentAllocCalled = 0;
// Number of SEGMENT_FREE trace entries observed by SegmentFreeTraceTracker.
static int segmentFreeCalled = 0;
6 
SegmentAllocTraceTracker(const c10::cuda::CUDACachingAllocator::TraceEntry & te)7 static void SegmentAllocTraceTracker(
8     const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
9   if (te.action_ ==
10       c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_ALLOC) {
11     segmentAllocCalled++;
12   }
13 }
14 
SegmentFreeTraceTracker(const c10::cuda::CUDACachingAllocator::TraceEntry & te)15 static void SegmentFreeTraceTracker(
16     const c10::cuda::CUDACachingAllocator::TraceEntry& te) {
17   if (te.action_ ==
18       c10::cuda::CUDACachingAllocator::TraceEntry::Action::SEGMENT_FREE) {
19     segmentFreeCalled++;
20   }
21 }
22 
allocateLargeBuffer()23 static void allocateLargeBuffer() {
24   const auto _500mb = 500 * 1024 * 1024;
25   auto* allocator = c10::cuda::CUDACachingAllocator::get();
26   auto buffer = allocator->allocate(_500mb);
27 }
28 
// Checks that trace trackers attached to the CUDA caching allocator observe
// one segment allocation when a large buffer is requested and one segment
// free when the cache is subsequently emptied.
TEST(AllocatorTraceTracker, TrackMallocFree) {
  // Start from zero so the exact-count assertions below do not depend on
  // whatever the file-level counters held before this test ran.
  segmentAllocCalled = 0;
  segmentFreeCalled = 0;

  c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
      &SegmentAllocTraceTracker);
  c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker(
      &SegmentFreeTraceTracker);

  // The large request is expected to trigger a segment allocation. The
  // buffer is expected to be marked inactive once allocateLargeBuffer()
  // returns, and to be freed only when emptyCache() is called.
  allocateLargeBuffer();
  ASSERT_EQ(segmentAllocCalled, 1);

  // The buffer has already been released back to the allocator, so emptying
  // the cache is expected to trigger exactly one segment free.
  c10::cuda::CUDACachingAllocator::emptyCache();
  ASSERT_EQ(segmentFreeCalled, 1);
}
46 
main(int argc,char * argv[])47 int main(int argc, char* argv[]) {
48   ::testing::InitGoogleTest(&argc, argv);
49   c10::cuda::CUDACachingAllocator::init(1);
50   return RUN_ALL_TESTS();
51 }
52