• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <c10/cuda/CUDACachingAllocator.h>
5 
6 #include <ATen/test/allocator_clone_test.h>
7 
8 #include <torch/csrc/cuda/CUDAPluggableAllocator.h>
9 
// Runs the shared allocator-clone check from ATen/test/allocator_clone_test.h
// against the CUDA caching allocator.
TEST(AllocatorTestCUDA, test_clone) {
  auto* const caching_allocator = c10::cuda::CUDACachingAllocator::get();
  test_allocator_clone(caching_allocator);
}
13 
// How many times each dummy free callback below has been invoked.
static int called_dummy_free_0 = 0;
static int called_dummy_free_1 = 0;

// Stub allocation callback for a pluggable allocator: allocates nothing and
// always returns nullptr. All arguments are ignored.
void* dummy_alloc_0(size_t /*size*/, int /*device*/, void* /*stream*/) {
  return nullptr;
}

// Stub free callbacks: they release nothing and only record that they ran,
// so a test can tell which deleter a DataPtr was bound to.
void dummy_free_0(void* /*data*/, size_t /*size*/, int /*device*/, void* /*stream*/) {
  ++called_dummy_free_0;
}

void dummy_free_1(void* /*data*/, size_t /*size*/, int /*device*/, void* /*stream*/) {
  ++called_dummy_free_1;
}
24 
25 // Tests that data_ptrs have their respective deleters
26 // when mixing allocators
TEST(AllocatorTestCUDA,test_pluggable_allocator_deleters)27 TEST(AllocatorTestCUDA, test_pluggable_allocator_deleters) {
28   // Create a tensor with dummy_allocator_0, where dummy_free_0 is the deleter
29   auto dummy_allocator_0 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_0);
30   c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_0.get());
31   at::Tensor a = at::empty({0}, at::TensorOptions().device(at::kCUDA));
32 
33   // Create a tensor with dummy_allocator_1, where dummy_free_1 is the deleter
34   auto dummy_allocator_1 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_1);
35   c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_1.get());
36   at::Tensor b = at::empty({0}, at::TensorOptions().device(at::kCUDA));
37 
38   // Manually use a's deleter
39   auto* ctx = a.storage().data_ptr().get_context();
40   a.storage().data_ptr().get_deleter()(ctx);
41   a.storage().mutable_data_ptr().release_context();
42 
43   // a's deleter is dummy_free_0
44   // dummy_free_0 should be called above, so called_dummy_free_0 should be 1
45   ASSERT_TRUE(called_dummy_free_0 == 1);
46 
47   // Manually use b's deleter
48   ctx = b.storage().data_ptr().get_context();
49   b.storage().data_ptr().get_deleter()(ctx);
50   b.storage().mutable_data_ptr().release_context();
51 
52   // b's deleter is dummy_free_1
53   // dummy_free_1 should be called above, so called_dummy_free_1 should be 1
54   ASSERT_TRUE(called_dummy_free_1 == 1);
55 }
56