#include #include #include namespace c10::impl { void cow::cow_deleter(void* ctx) { static_cast(ctx)->decrement_refcount(); } cow::COWDeleterContext::COWDeleterContext( std::unique_ptr data) : data_(std::move(data)) { // We never wrap a COWDeleterContext. TORCH_INTERNAL_ASSERT(data_.get_deleter() != cow::cow_deleter); } auto cow::COWDeleterContext::increment_refcount() -> void { auto refcount = ++refcount_; TORCH_INTERNAL_ASSERT(refcount > 1); } auto cow::COWDeleterContext::decrement_refcount() -> std::variant { auto refcount = --refcount_; TORCH_INTERNAL_ASSERT(refcount >= 0, refcount); if (refcount == 0) { std::unique_lock lock(mutex_); auto result = std::move(data_); lock.unlock(); delete this; return {std::move(result)}; } return std::shared_lock(mutex_); } cow::COWDeleterContext::~COWDeleterContext() { TORCH_INTERNAL_ASSERT(refcount_ == 0); } } // namespace c10::impl