1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4 #include <c10/cuda/CUDACachingAllocator.h>
5
6 #include <ATen/test/allocator_clone_test.h>
7
8 #include <torch/csrc/cuda/CUDAPluggableAllocator.h>
9
// Runs the shared `test_allocator_clone` checks against the CUDA caching
// allocator currently installed in the global allocator slot.
TEST(AllocatorTestCUDA, test_clone) {
  auto* const caching_allocator = c10::cuda::CUDACachingAllocator::get();
  test_allocator_clone(caching_allocator);
}
13
// Number of times dummy_free_0 / dummy_free_1 have been invoked. The pluggable
// allocator deleter test inspects these to tell which deleter actually ran.
static int called_dummy_free_0 = 0;
static int called_dummy_free_1 = 0;

// Alloc callback for the dummy pluggable allocators: it never allocates and
// always returns nullptr. Shared by both dummy allocators in this file.
// Internal linkage (like the counters above) keeps these generic names from
// colliding with other test translation units linked into the same binary.
static void* dummy_alloc_0(size_t /*size*/, int /*device*/, void* /*stream*/) {
  return nullptr;
}

// Free callback that only records that it was called (bumps called_dummy_free_0).
static void dummy_free_0(
    void* /*data*/,
    size_t /*size*/,
    int /*device*/,
    void* /*stream*/) {
  called_dummy_free_0++;
}

// Free callback that only records that it was called (bumps called_dummy_free_1).
static void dummy_free_1(
    void* /*data*/,
    size_t /*size*/,
    int /*device*/,
    void* /*stream*/) {
  called_dummy_free_1++;
}
24
25 // Tests that data_ptrs have their respective deleters
26 // when mixing allocators
TEST(AllocatorTestCUDA,test_pluggable_allocator_deleters)27 TEST(AllocatorTestCUDA, test_pluggable_allocator_deleters) {
28 // Create a tensor with dummy_allocator_0, where dummy_free_0 is the deleter
29 auto dummy_allocator_0 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_0);
30 c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_0.get());
31 at::Tensor a = at::empty({0}, at::TensorOptions().device(at::kCUDA));
32
33 // Create a tensor with dummy_allocator_1, where dummy_free_1 is the deleter
34 auto dummy_allocator_1 = torch::cuda::CUDAPluggableAllocator::createCustomAllocator(dummy_alloc_0, dummy_free_1);
35 c10::cuda::CUDACachingAllocator::allocator.store(dummy_allocator_1.get());
36 at::Tensor b = at::empty({0}, at::TensorOptions().device(at::kCUDA));
37
38 // Manually use a's deleter
39 auto* ctx = a.storage().data_ptr().get_context();
40 a.storage().data_ptr().get_deleter()(ctx);
41 a.storage().mutable_data_ptr().release_context();
42
43 // a's deleter is dummy_free_0
44 // dummy_free_0 should be called above, so called_dummy_free_0 should be 1
45 ASSERT_TRUE(called_dummy_free_0 == 1);
46
47 // Manually use b's deleter
48 ctx = b.storage().data_ptr().get_context();
49 b.storage().data_ptr().get_deleter()(ctx);
50 b.storage().mutable_data_ptr().release_context();
51
52 // b's deleter is dummy_free_1
53 // dummy_free_1 should be called above, so called_dummy_free_1 should be 1
54 ASSERT_TRUE(called_dummy_free_1 == 1);
55 }
56