#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAFunctions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

#include <torch/csrc/cuda/CUDAPluggableAllocator.h>

namespace torch::cuda::CUDAPluggableAllocator {

CUDAPluggableAllocatorDeleterContext::CUDAPluggableAllocatorDeleterContext(
    std::function<FreeFuncType> free_fn,
    void* data,
    size_t size,
    int device,
    cudaStream_t stream)
    : free_fn_(std::move(free_fn)),
      data_(data),
      size_(size),
      device_(device),
      stream_(stream) {}

void CUDAPluggableAllocatorDeleterContext::free() {
  free_fn_(data_, size_, device_, stream_);
  delete this;
}

int device_count = 0;

void custom_raw_deleter(void* ptr);

_AllocationMetadata::_AllocationMetadata()
    : size(0), device_idx(-1), stream{} {}

_AllocationMetadata::_AllocationMetadata(
    size_t size,
    c10::DeviceIndex device_idx,
    cudaStream_t stream)
    : size(size), device_idx(device_idx), stream(stream) {}

// This is a fast API to register allocators based on plain function pointers
// (i.e. from external .so libraries). It avoids having to link against
// libtorch for C++-based custom allocators, and it can also be used from
// Python.
CUDAPluggableAllocator::CUDAPluggableAllocator(
    std::function<MallocFuncType> alloc_fn,
    std::function<FreeFuncType> free_fn)
    : alloc_fn_(std::move(alloc_fn)), free_fn_(std::move(free_fn)) {}

CUDAPluggableAllocator::CUDAPluggableAllocator(CUDAPluggableAllocator& other)
    : alloc_fn_(other.alloc_fn_),
      free_fn_(other.free_fn_),
      init_fn_(other.init_fn_),
      reset_fn_(other.reset_fn_),
      memory_fraction_fn_(other.memory_fraction_fn_),
      base_alloc_fn_(other.base_alloc_fn_),
      record_stream_fn_(other.record_stream_fn_),
      begin_allocate_to_pool_fn_(other.begin_allocate_to_pool_fn_),
      end_allocate_to_pool_fn_(other.end_allocate_to_pool_fn_),
      relase_pool_fn_(other.relase_pool_fn_) {}

void CUDAPluggableAllocator::set_init_fn(std::function<void(int)> init_fn) {
  init_fn_ = std::move(init_fn);
}

void CUDAPluggableAllocator::set_reset_fn(std::function<void()> reset_fn) {
  reset_fn_ = std::move(reset_fn);
}

void CUDAPluggableAllocator::set_memory_fraction_fn(
    std::function<void(double, int)> memory_fraction_fn) {
  memory_fraction_fn_ = std::move(memory_fraction_fn);
}

void CUDAPluggableAllocator::set_base_alloc_fn(
    std::function<void*(void*, size_t*)> base_alloc_fn) {
  base_alloc_fn_ = std::move(base_alloc_fn);
}

void CUDAPluggableAllocator::set_record_stream_fn(
    std::function<void(void* ptr, cudaStream_t stream)> record_stream_fn) {
  record_stream_fn_ = std::move(record_stream_fn);
}

void CUDAPluggableAllocator::set_begin_allocate_to_pool(
    std::function<
        void(int, c10::cuda::MempoolId_t, std::function<bool(cudaStream_t)>)>
        capture_begin_fn) {
  begin_allocate_to_pool_fn_ = std::move(capture_begin_fn);
}

void CUDAPluggableAllocator::set_end_allocate_to_pool_fn(
    std::function<void(int, c10::cuda::MempoolId_t)> capture_about_to_end_fn) {
  end_allocate_to_pool_fn_ = std::move(capture_about_to_end_fn);
}

void CUDAPluggableAllocator::set_release_pool(
    std::function<void(int, c10::cuda::MempoolId_t)> capture_destroy_fn) {
  relase_pool_fn_ = std::move(capture_destroy_fn);
}

void* CUDAPluggableAllocator::malloc(
    size_t size,
    c10::DeviceIndex device,
    cudaStream_t stream) {
  void* r = alloc_fn_(size, device, stream);
  {
    // Remember size/device/stream so raw_delete() can hand them back to the
    // user-provided free function.
    const std::lock_guard<std::mutex> lock(allocator_mutex_);
    allocation_metadata_.emplace(r, _AllocationMetadata(size, device, stream));
  }
  return r;
}

c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) {
  c10::DeviceIndex device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
  void* r = this->malloc(size, device, stream);
  auto* ctx = new CUDAPluggableAllocatorDeleterContext(
      free_fn_, r, size, device, stream);
  c10::DataPtr data_ptr = {
      r, ctx, raw_deleter(), c10::Device(c10::DeviceType::CUDA, device)};
  return data_ptr;
}

c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const {
  return &custom_raw_deleter;
}
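// Illustrative sketch (not compiled into this file): a minimal pair of plug-in
// functions with the MallocFuncType/FreeFuncType signatures consumed by
// CUDAPluggableAllocator. The names example_malloc/example_free are
// hypothetical; a real allocator would typically add caching or pooling.
#if 0
void* example_malloc(size_t size, int device, cudaStream_t stream) {
  void* ptr = nullptr;
  // Synchronous allocation on the current device; the stream is unused here,
  // but a caching allocator could use it for stream-ordered allocation.
  C10_CUDA_CHECK(cudaMalloc(&ptr, size));
  return ptr;
}

void example_free(void* ptr, size_t size, int device, cudaStream_t stream) {
  C10_CUDA_CHECK(cudaFree(ptr));
}
#endif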
void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) {
  c10::DeviceIndex device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device);
  return malloc(nbytes, device, stream);
}

void* CUDAPluggableAllocator::raw_alloc_with_stream(
    size_t nbytes,
    cudaStream_t stream) {
  c10::DeviceIndex device = -1;
  C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
  return malloc(nbytes, device, stream);
}

void CUDAPluggableAllocator::raw_delete(void* ptr) {
  cudaStream_t stream{};
  c10::DeviceIndex device_idx = -1;
  size_t size = 0;
  {
    const std::lock_guard<std::mutex> lock(allocator_mutex_);
    TORCH_CHECK(
        allocation_metadata_.count(ptr),
        "Trying to free a pointer not allocated here");
    _AllocationMetadata& metadata = allocation_metadata_[ptr];
    size = metadata.size;
    device_idx = metadata.device_idx;
    stream = metadata.stream;
    allocation_metadata_.erase(ptr);
  }
  free_fn_(ptr, size, device_idx, stream);
}

void CUDAPluggableAllocator::init(int device_count) {
  if (init_fn_) {
    init_fn_(device_count);
  }
  initialized_ = true;
}

bool CUDAPluggableAllocator::initialized() {
  return initialized_;
}

void CUDAPluggableAllocator::setMemoryFraction(
    double fraction,
    c10::DeviceIndex device) {
  if (memory_fraction_fn_) {
    memory_fraction_fn_(fraction, device);
  }
}

void CUDAPluggableAllocator::emptyCache() {
  if (reset_fn_) {
    return reset_fn_();
  }
}

void CUDAPluggableAllocator::cacheInfo(
    c10::DeviceIndex device,
    size_t* largestBlock) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support cacheInfo. "
      "If you need it, please file an issue describing your use case.");
}

void* CUDAPluggableAllocator::getBaseAllocation(void* ptr, size_t* size) {
  if (base_alloc_fn_) {
    return base_alloc_fn_(ptr, size);
  } else {
    return ptr;
  }
}

void CUDAPluggableAllocator::recordStream(
    const c10::DataPtr& ptr,
    streamType stream) {
  if (record_stream_fn_) {
    record_stream_fn_(ptr.get(), stream);
  }
}

c10::CachingDeviceAllocator::DeviceStats CUDAPluggableAllocator::getDeviceStats(
    c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support getDeviceStats. "
      "If you need it, please file an issue describing your use case.");
}

void CUDAPluggableAllocator::resetAccumulatedStats(c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support resetAccumulatedStats. "
      "If you need it, please file an issue describing your use case.");
}

void CUDAPluggableAllocator::resetPeakStats(c10::DeviceIndex device) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support resetPeakStats. "
      "If you need it, please file an issue describing your use case.");
}

c10::cuda::CUDACachingAllocator::SnapshotInfo CUDAPluggableAllocator::
    snapshot() {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support snapshot. "
      "If you need it, please file an issue describing your use case.");
}

c10::cuda::CUDACachingAllocator::ShareableHandle CUDAPluggableAllocator::
    shareIpcHandle(void* ptr) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support shareIpcHandle. "
      "If you need it, please file an issue describing your use case.");
}

std::shared_ptr<void> CUDAPluggableAllocator::getIpcDevPtr(std::string handle) {
  TORCH_CHECK(
      false,
      "CUDAPluggableAllocator does not yet support getIpcDevPtr. "
      "If you need it, please file an issue describing your use case.");
}
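// Illustrative sketch (not compiled into this file): the hooks used above are
// optional and only invoked when set, so a plug-in allocator can start with
// just alloc/free and opt into more behavior later. The lambdas below are
// hypothetical placeholders.
#if 0
void register_optional_hooks(CUDAPluggableAllocator& allocator) {
  // Called once from init() with the device count.
  allocator.set_init_fn([](int device_count) { /* per-device setup */ });
  // Called from emptyCache(); a caching plug-in would release cached blocks.
  allocator.set_reset_fn([]() { /* drop cached blocks */ });
  // Called from setMemoryFraction(fraction, device).
  allocator.set_memory_fraction_fn(
      [](double fraction, int device) { /* cap per-device usage */ });
  // Called from recordStream() so the plug-in can defer frees until the
  // recorded stream has finished using the block.
  allocator.set_record_stream_fn(
      [](void* ptr, cudaStream_t stream) { /* track stream use */ });
}
#endif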
" "If you need it, please file an issue describing your use case."); } // CUDAGraph interactions void CUDAPluggableAllocator::beginAllocateToPool( c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id, std::function filter) { if (begin_allocate_to_pool_fn_) { begin_allocate_to_pool_fn_(device, mempool_id, std::move(filter)); } } void CUDAPluggableAllocator::endAllocateToPool( c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id) { if (end_allocate_to_pool_fn_) { end_allocate_to_pool_fn_(device, mempool_id); } } void CUDAPluggableAllocator::releasePool( c10::DeviceIndex device, c10::cuda::MempoolId_t mempool_id) { if (relase_pool_fn_) { relase_pool_fn_(device, mempool_id); } } void CUDAPluggableAllocator::recordHistory( bool enabled, c10::cuda::CUDACachingAllocator::CreateContextFn context_recorder, size_t alloc_trace_max_entries, c10::cuda::CUDACachingAllocator::RecordContext when) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support recordHistory. " "If you need it, please file an issue describing your use case."); } void CUDAPluggableAllocator::attachOutOfMemoryObserver( c10::cuda::CUDACachingAllocator::OutOfMemoryObserver observer) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support attachOutOfMemoryObserver. " "If you need it, please file an issue describing your use case."); } void CUDAPluggableAllocator::attachAllocatorTraceTracker( c10::cuda::CUDACachingAllocator::AllocatorTraceTracker tracker) { TORCH_CHECK( false, "CUDAPluggableAllocator does not support attachAllocatorTraceTracker. " "attachAllocatorTraceTracker is only used inside Pytorch."); } std::shared_ptr CUDAPluggableAllocator::getCheckpointState( c10::DeviceIndex device, at::cuda::MempoolId_t id) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support getCheckpointState. " "If you need it, please file an issue describing your use case."); } c10::cuda::CUDACachingAllocator::CheckpointDelta CUDAPluggableAllocator:: setCheckpointPoolState( c10::DeviceIndex device, std::shared_ptr pps) { TORCH_CHECK( false, "CUDAPluggableAllocator does not yet support setCheckpointPoolState. 
" "If you need it, please file an issue describing your use case."); } void CUDAPluggableAllocator::enablePeerAccess( c10::DeviceIndex dev, c10::DeviceIndex dev_to_access) { c10::cuda::CUDAGuard device_guard(dev); cudaError_t err = cudaDeviceEnablePeerAccess(dev_to_access, 0); if (err == cudaErrorPeerAccessAlreadyEnabled) { // ignore and clear the error if access was already enabled (void)cudaGetLastError(); } else { C10_CUDA_CHECK(err); } } cudaError_t CUDAPluggableAllocator::memcpyAsync( void* dst, int dstDevice, const void* src, int srcDevice, size_t count, cudaStream_t stream, bool p2p_enabled) { return cudaMemcpyAsync(dst, src, count, cudaMemcpyDeviceToDevice, stream); } std::string CUDAPluggableAllocator::name() { return "pluggable"; } void CUDAPluggableAllocator::copy_data( void* dest, const void* src, std::size_t count) const { C10_CUDA_CHECK( cudaMemcpy(dest, src, count, cudaMemcpyKind::cudaMemcpyDeviceToDevice)); } std::shared_ptr current_custom_allocator; std::shared_ptr getCurrentAllocator() { return current_custom_allocator; } // TODO: add more functions in the argument std::shared_ptr createCustomAllocator( std::function alloc_fn, std::function free_fn) { std::shared_ptr allocator( new CUDAPluggableAllocator(std::move(alloc_fn), std::move(free_fn))); allocator->init(device_count); return allocator; } void changeCurrentAllocator( const std::shared_ptr& allocator) { TORCH_CHECK( !c10::cuda::CUDACachingAllocator::allocator.load()->initialized(), "Can't swap an already initialized allocator"); c10::cuda::CUDACachingAllocator::allocator.store(allocator.get()); current_custom_allocator = allocator; } void custom_raw_deleter(void* ctx) { reinterpret_cast(ctx)->free(); } } // namespace torch::cuda::CUDAPluggableAllocator