//  Copyright © 2022 Apple Inc.

#include <ATen/mps/MPSAllocator.h>
#include <ATen/CPUFunctions.h>
#include <c10/core/Allocator.h>
#include <c10/core/Storage.h>

#include <iostream>

namespace at::mps {

C10_DEFINE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);

namespace HeapAllocator {

uint64_t BufferBlock::buffer_counter = 0;
uint64_t HeapBlock::heap_counter = 0;

void MPSHeapAllocatorImpl::init_allocator() {
  init_buffer_pools();

  // debug verbosity flags (see DebugVerbosity enum)
  static const char* verbosity_str = getenv("PYTORCH_DEBUG_MPS_ALLOCATOR");
  m_debug_verbosity = verbosity_str ? strtol(verbosity_str, nullptr, 0) : DebugVerbosity::SILENT;

  static const char* high_watermark_ratio_str = getenv("PYTORCH_MPS_HIGH_WATERMARK_RATIO");
  const double high_watermark_ratio =
      high_watermark_ratio_str ? strtod(high_watermark_ratio_str, nullptr) : default_high_watermark_ratio;
  setHighWatermarkRatio(high_watermark_ratio);

  const double default_low_watermark_ratio =
      m_device.hasUnifiedMemory ? default_low_watermark_ratio_unified : default_low_watermark_ratio_discrete;
  static const char* low_watermark_ratio_str = getenv("PYTORCH_MPS_LOW_WATERMARK_RATIO");
  const double low_watermark_ratio =
      low_watermark_ratio_str ? strtod(low_watermark_ratio_str, nullptr) : default_low_watermark_ratio;
  setLowWatermarkRatio(low_watermark_ratio);
}

void MPSHeapAllocatorImpl::init_buffer_pools() {
  // using a container for pools to simplify iterating over them
  // Pool of large buffers with private storage mode
  m_pools.emplace(BufferPool::Kind::PRIVATE_LARGE,
                  std::make_unique<BufferPool>(m_device, UsageFlags::PRIVATE | UsageFlags::HAZARD));
  // Pool of large buffers with shared storage mode
  m_pools.emplace(BufferPool::Kind::SHARED_LARGE,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SHARED | UsageFlags::HAZARD));
  // Pool of small buffers with private storage mode
  m_pools.emplace(BufferPool::Kind::PRIVATE_SMALL,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::PRIVATE | UsageFlags::HAZARD));
  // Pool of small buffers with shared storage mode
  m_pools.emplace(BufferPool::Kind::SHARED_SMALL,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::SHARED | UsageFlags::HAZARD));
  // Pool of small buffers with shared storage mode used to allocate and copy Scalars
  // from CPU to Metal buffers (see allocScalarBufferWithValue()).
  // no Hazard Tracking required for the Scalar pool (synchronized manually).
  m_pools.emplace(BufferPool::Kind::SCALAR,
                  std::make_unique<BufferPool>(m_device, UsageFlags::SMALL | UsageFlags::SHARED | UsageFlags::SCALAR));
}
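// get_pool() maps an allocation request onto one of the pools created above: the SCALAR
// pool for small CPU-to-GPU scalar uploads, SHARED_SMALL for tiny requests on
// unified-memory devices, and otherwise the SMALL or LARGE pool matching the requested
// storage mode (shared vs. private).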
BufferPool& MPSHeapAllocatorImpl::get_pool(size_t requested_size, size_t aligned_size, uint32_t usage) {
  BufferPool::Kind poolKind;

  if (usage & UsageFlags::SCALAR) {
    poolKind = BufferPool::Kind::SCALAR;
  } else if (requested_size <= kMaxScalarAlloc && m_device.hasUnifiedMemory) {
    poolKind = BufferPool::Kind::SHARED_SMALL;
  } else if (aligned_size <= kMaxSmallAlloc) {
    poolKind = (usage & UsageFlags::SHARED) ? BufferPool::Kind::SHARED_SMALL : BufferPool::Kind::PRIVATE_SMALL;
  } else {
    poolKind = (usage & UsageFlags::SHARED) ? BufferPool::Kind::SHARED_LARGE : BufferPool::Kind::PRIVATE_LARGE;
  }
  return *m_pools[poolKind];
}

size_t MPSHeapAllocatorImpl::get_allocation_size(size_t size, uint32_t usage) const {
  MTLSizeAndAlign sizeAlign = [m_device heapBufferSizeAndAlignWithLength:size options:HeapBlock::getOptions(usage)];
  return BufferBlock::alignUp(sizeAlign.size, sizeAlign.align);
}

void MPSHeapAllocatorImpl::setHighWatermarkRatio(double ratio) {
  TORCH_CHECK(ratio >= 0.0 && ratio <= default_high_watermark_upper_bound, "invalid high watermark ratio ", ratio);
  m_max_total_allowed_size =
      (ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
  if (m_debug_verbosity & DebugVerbosity::PROFILING) {
    std::cerr << "\nHigh watermark memory allocation limit: "
              << (ratio == 0.0 ? "unlimited" : format_size(m_max_total_allowed_size)) << "\n";
  }
  m_high_watermark_ratio = ratio;
}

void MPSHeapAllocatorImpl::setLowWatermarkRatio(double ratio) {
  // used for comparison with the low watermark ratio
  const double high_watermark_limit =
      m_high_watermark_ratio == 0.0 ? default_high_watermark_upper_bound : m_high_watermark_ratio;
  TORCH_CHECK(ratio >= 0.0 && ratio <= high_watermark_limit, "invalid low watermark ratio ", ratio);
  // we use this to detect if there's memory pressure
  m_low_watermark_limit =
      (ratio == 0.0) ? std::numeric_limits<size_t>::max() : static_cast<size_t>(ratio * (double)max_device_size());
  if (m_debug_verbosity & DebugVerbosity::PROFILING) {
    std::cerr << "Low watermark memory allocation limit: "
              << (ratio == 0.0 ? "unlimited" : format_size(m_low_watermark_limit)) << "\n";
  }
  m_low_watermark_ratio = ratio;
}

HeapBlock* MPSHeapAllocatorImpl::get_free_heap(AllocParams& params) {
  BufferPool& pool = *params.pool;
  HeapBlock* heap_block = nullptr;
  HeapBlock search_key(params.size());

  auto it = pool.heaps.lower_bound(&search_key);
  if (it == pool.heaps.end()) {
    heap_block = HeapBlock::createHeapBlock(params, pool.device, pool.usage);
    if (heap_block) {
      m_total_allocated_memory += heap_block->size.total;
      if (m_debug_verbosity & DebugVerbosity::ALLOCATIONS) {
        std::cerr << "\nAllocated " << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private") << " heap #"
                  << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
                  << " (#heaps: " << (pool.heaps.size() + 1)
                  << ", current allocated: " << format_size(current_allocated_size()) << ")\n";
      }
    }
  } else {
    heap_block = *it;
    // remove and re-insert heap in the set later after a buffer is created.
    // this ensures updating the order of heaps based on their new available sizes
    pool.heaps.erase(it);
  }
  return heap_block;
}
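// alloc_buffer() sub-allocates a new MTLBuffer from a (possibly newly created) heap in the
// target pool, registers the resulting BufferBlock in m_allocated_buffers, and bails out
// early if the allocation would exceed the high watermark limit.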
bool MPSHeapAllocatorImpl::alloc_buffer(AllocParams& params) {
  if (m_max_total_allowed_size != std::numeric_limits<size_t>::max() &&
      current_allocated_size() + params.size() > m_max_total_allowed_size) {
    return false;
  }
  HeapBlock* heap = get_free_heap(params);
  if (!heap) {
    return false; // this will cause releasing pool buffers to free up memory
  }
  BufferPool& pool = *params.pool;

  id<MTLBuffer> buffer = heap->newMTLBuffer(params.size(), pool.usage);
  // this should never happen as the backing memory (i.e., heap) was allocated successfully.
  TORCH_INTERNAL_ASSERT(buffer);
  // insert the heap after a buffer was created on it, to update the order of the heaps set
  pool.heaps.insert(heap);
  params.buffer_block = new BufferBlock(params.size(), params.requested_size, buffer, heap);
  m_allocated_buffers[params.buffer_block->buffer] = params.buffer_block;
  pool.allocated_size += params.size();
  pool.n_buffers++;

  if ((m_debug_verbosity & DebugVerbosity::ALLOCATIONS) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Allocated " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
              << params.buffer_block->buf_id << " of size " << format_size(params.size()) << " at "
              << params.buffer_block->buffer << " from heap #" << heap->heap_id
              << " (requested: " << format_size(params.requested_size)
              << ", heap: " << format_size(heap->size.available)
              << ", total: " << format_size(m_total_allocated_memory) << ")\n";
  }
  return true;
}
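// get_free_buffer() tries to satisfy the request from the pool's cached (available) buffers.
// Reusing an oversized cached buffer is a last resort: if an existing heap can fit the request,
// or the oversized buffer can be released to make room inside its own heap, those paths are
// preferred so heap capacity isn't wasted.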
bool MPSHeapAllocatorImpl::get_free_buffer(AllocParams& params) {
  // this helps to monitor "implicit" allocations from MPS backend and to prevent OOM and system failure.
  if (m_high_watermark_ratio > 0.0 && current_allocated_size() + params.size() > m_max_total_allowed_size) {
    return false;
  }
  BufferPool& pool = *params.pool;
  // track buffer reuse intervals only on large pool when low watermark limit is enabled.
  if (m_low_watermark_ratio > 0.0 && !(pool.usage & UsageFlags::SMALL)) {
    for (auto& b : pool.available_buffers) {
      ++b->gc_count;
    }
  }
  auto it = pool.available_buffers.lower_bound(&params.search_key);
  if (it != pool.available_buffers.end()) {
    BufferBlock* buffer_block = *it;

    // the logic in here is simple: keep reusing existing heaps' capacity as long as possible (by splitting
    // or releasing oversized buffers, if required), and avoid 'new' heap allocations as much as possible.
    if (buffer_block->size <= params.size() + kLargeHeap) {
      // return the existing buffer if it already fits the requested size (i.e., not oversized)
      params.buffer_block = buffer_block;
    } else {
      HeapBlock search_key(params.size());
      // if there's an 'existing' heap with enough capacity, then don't
      // return the oversized buffer and sub-allocate from that existing heap.
      if (pool.heaps.lower_bound(&search_key) != pool.heaps.end()) {
        params.buffer_block = nullptr;
      } else if (buffer_block->retainCount() <= 1) {
        // otherwise, if the buffer is releasable immediately, we make room by releasing the
        // buffer and reuse the freed space within its heap container for the new smaller allocation
        release_buffer(buffer_block, false);
        // this will skip unnecessary garbage collection as we'll reuse the newly released space
        params.has_memory_pressure = false;
      } else if (params.has_memory_pressure) {
        // the oversized buffer is busy and not reusable at the moment. So release it (and potentially its heap
        // container) in the allocator, and ARC will later free up its backing memory when the busy command buffer finishes.
        release_buffer(buffer_block, true);
      } else {
        // only if there's no memory pressure will we reuse the oversized buffer
        params.buffer_block = buffer_block;
      }
    }
  }
  if (!params.buffer_block) {
    return false; // this will make the allocator allocate a new buffer
  }
  pool.available_buffers.erase(params.buffer_block);
  params.buffer_block->requested_size = params.requested_size;
  params.buffer_block->gc_count = 0;
  pool.available_size -= params.buffer_block->size;

  if ((m_debug_verbosity & DebugVerbosity::RECYCLES) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Reusing " << ((params.pool->usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((params.pool->usage & UsageFlags::SCALAR) ? " scalar" : "") << " buffer #"
              << params.buffer_block->buf_id << " of size " << format_size(params.buffer_block->size) << " at "
              << params.buffer_block->buffer << " (requested: " << format_size(params.requested_size)
              << ", use#: " << params.buffer_block->use_count + 1
              << ", retain#: " << params.buffer_block->retainCount() << ")\n";
  }
  return true;
}
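// alloc_buffer_block() is the top-level allocation routine: it first tries to reuse a cached
// buffer, then falls back to a cascade of garbage collection, ALLOCATION_FAILED callbacks and
// progressively releasing cached buffers before finally reporting an out-of-memory error.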
BufferBlock* MPSHeapAllocatorImpl::alloc_buffer_block(size_t size, uint32_t usage) {
  TORCH_CHECK(size < m_max_buffer_size, "Invalid buffer size: ", format_size(size));

  size_t alloc_size = get_allocation_size(size, usage);
  auto& pool = get_pool(size, alloc_size, usage);
  AllocParams params(alloc_size, size, &pool);
  // we only care about memory pressure when we're allocating large buffers and the
  // low watermark limit has been reached
  params.has_memory_pressure = !(pool.usage & UsageFlags::SMALL) && getLowWatermarkValue() <= 0;
  params.has_unified_memory = m_device.hasUnifiedMemory;

  // first, try to get a block from the existing pool.
  bool block_found = get_free_buffer(params);
  if (!block_found) {
    // do garbage collection if memory pressure is high and there's enough memory in the pool
    if (params.has_memory_pressure && alloc_size < pool.available_size) {
      garbage_collect_cached_buffers(params);
    }

    block_found =
        // Attempt allocate
        alloc_buffer(params) ||
        // Callbacks might release more memory (e.g. by forcing a GC in the host language), thus
        // we can retry getting a free buffer in the pool before trying to alloc again.
        (trigger_memory_callbacks(nullptr, IMpsAllocatorCallback::EventType::ALLOCATION_FAILED) &&
         get_free_buffer(params)) ||
        // Free enough available cached blocks to satisfy alloc and retry alloc.
        (release_available_cached_buffers(params) && alloc_buffer(params)) ||
        // Free all cached buffers and retry alloc.
        (release_cached_buffers() && alloc_buffer(params));
  }

  BufferBlock* buffer_block = params.buffer_block;
  // the OOM could be triggered if:
  // 1- the High Watermark limit has been reached (if enabled)
  // 2- ran out of device memory, or the memory fragmentation is so high that a contiguous
  //    chunk of the requested size couldn't be found.
  if (!block_found || !buffer_block) {
    if (m_high_watermark_ratio > 0.0) {
      TORCH_CHECK(
          false,
          "MPS backend out of memory (MPS allocated: ",
          format_size(m_total_allocated_memory),
          ", other allocations: ",
          format_size(current_allocated_size() - m_total_allocated_memory),
          ", max allowed: ",
          format_size(m_max_total_allowed_size),
          "). Tried to allocate ",
          format_size(alloc_size),
          " on ",
          ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
          " pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).");
    } else {
      TORCH_CHECK(
          false,
          "MPS backend out of memory (MPS allocated: ",
          format_size(m_total_allocated_memory),
          ", other allocations: ",
          format_size(current_allocated_size() - m_total_allocated_memory),
          "). Tried to allocate ",
          format_size(alloc_size),
          " on ",
          ((pool.usage & UsageFlags::SHARED) ? "shared" : "private"),
          " pool.");
    }
  }
  buffer_block->in_use = true;
  buffer_block->use_count++;
  m_current_allocated_memory += buffer_block->size;
  return buffer_block;
}

void MPSHeapAllocatorImpl::free_buffer(BufferBlock* buffer_block) {
  TORCH_INTERNAL_ASSERT(buffer_block->in_use);

  BufferPool& pool = *buffer_block->heap->pool;
  // Makes sure the BufferBlock* isn't already present in the pool we're freeing it back into.
  TORCH_INTERNAL_ASSERT(pool.available_buffers.insert(buffer_block).second);
  pool.available_size += buffer_block->size;
  buffer_block->shape.clear(); // reset shape
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(m_current_allocated_memory >= buffer_block->size);
  m_current_allocated_memory -= buffer_block->size;
  if (buffer_block->event) {
    // returns the MPSEvent back to MPSEventPool
    buffer_block->event.reset(nullptr);
  }
  buffer_block->in_use = false;
}

BufferBlock* MPSHeapAllocatorImpl::get_allocated_buffer_block(const void* ptr) {
  auto it = m_allocated_buffers.find(ptr);
  if (it == m_allocated_buffers.end()) {
    return nullptr;
  }
  return it->second;
}
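// release_buffer() returns a buffer's space to its heap and, when remove_empty_heap is set and
// the heap becomes empty, releases the heap itself. If the released buffer is still retained by
// an in-flight command buffer, the heap's available size is updated later from a completion
// handler (see heaps_pending_update below).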
bool MPSHeapAllocatorImpl::release_buffer(BufferBlock* buffer_block, bool remove_empty_heap) {
  HeapBlock* heap_block = buffer_block->heap;
  BufferPool& pool = *heap_block->pool;
  pool.allocated_size -= buffer_block->size;
  pool.available_size -= buffer_block->size;
  m_allocated_buffers.erase(buffer_block->buffer);
  pool.available_buffers.erase(buffer_block);
  pool.n_buffers--;
  // will re-insert later to keep the heaps list sorted based on the heap's new available size (if heap not empty)
  pool.heaps.erase(heap_block);
  uint32_t retainCount = heap_block->releaseMTLBuffer(buffer_block->buffer);

  if ((m_debug_verbosity & DebugVerbosity::RELEASES) &&
      (!(m_debug_verbosity & DebugVerbosity::LARGE_ONLY) || !(pool.usage & UsageFlags::SMALL))) {
    std::cerr << "Released buffer #" << buffer_block->buf_id << " of size " << format_size(buffer_block->size)
              << " from heap #" << heap_block->heap_id << " (heap size: " << format_size(heap_block->size.available)
              << ", use#: " << buffer_block->use_count << ", retain#: " << retainCount
              << ", gc#: " << buffer_block->gc_count << ")\n";
  }
  delete buffer_block;

  if (remove_empty_heap && heap_block->n_buffers == 0) {
    pool.heaps_pending_update.erase(heap_block);
    m_total_allocated_memory -= heap_block->size.total;
    retainCount = heap_block->releaseMTLHeap();
    if (m_debug_verbosity & DebugVerbosity::RELEASES) {
      std::cerr << "Released heap #" << heap_block->heap_id << " of size " << format_size(heap_block->size.total)
                << " (current allocated: " << format_size(current_allocated_size()) << ", retain#: " << retainCount
                << ")\n";
    }
    delete heap_block;
    return true;
  } else {
    pool.heaps.insert(heap_block);
    // if the heap wasn't released and its released buffer is still busy in a command buffer, the available
    // size of the heap cannot be updated and we should defer updating until the command buffer finishes.
    if (retainCount > 1) {
      pool.heaps_pending_update.insert(heap_block);
      m_mutex.unlock();
      m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
        std::lock_guard lock(m_mutex);
        // check if the heap block still exists
        if (pool.heaps_pending_update.find(heap_block) != pool.heaps_pending_update.end()) {
          pool.heaps_pending_update.erase(heap_block);
          pool.heaps.erase(heap_block);
          heap_block->updateAvailableSize();
          pool.heaps.insert(heap_block);
        }
      });
      m_mutex.lock();
    }
  }
  return false;
}

void MPSHeapAllocatorImpl::release_buffers(BufferPool& pool) {
  if (pool.available_buffers.empty()) {
    return;
  }
  if ((m_debug_verbosity & DebugVerbosity::RELEASES)) {
    std::cerr << "Releasing " << pool.available_buffers.size() << " buffers from "
              << ((pool.usage & UsageFlags::SMALL) ? "small " : "large ")
              << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
              << ((pool.usage & UsageFlags::SCALAR) ? " scalar" : "")
              << " pool (total size: " << format_size(pool.allocated_size) << ", #buffers: " << pool.n_buffers
              << ")\n";
  }
  auto it = pool.available_buffers.begin();
  while (it != pool.available_buffers.end()) {
    BufferBlock* buffer_block = *it;
    ++it;
    release_buffer(buffer_block);
  }
}

bool MPSHeapAllocatorImpl::release_available_cached_buffers(AllocParams& params) {
  BufferPool& pool = *params.pool;

  if (pool.available_buffers.empty()) {
    return false;
  }
  auto it = pool.available_buffers.lower_bound(&params.search_key);
  if (it == pool.available_buffers.end()) {
    size_t totalReleased = 0;
    --it;
    while (totalReleased < params.search_key.size) {
      auto cur = it;
      totalReleased += (*it)->size;
      if (it != pool.available_buffers.begin()) {
        --it;
        release_buffer(*cur);
      } else {
        release_buffer(*cur);
        break;
      }
    }
    if (totalReleased < params.search_key.size) {
      return false;
    }
  } else {
    release_buffer(*it);
  }
  return true;
}

bool MPSHeapAllocatorImpl::release_cached_buffers() {
  if (m_debug_verbosity >= DebugVerbosity::PROFILING) {
    std::cerr << "Attempting to release cached buffers (MPS allocated: " << format_size(m_total_allocated_memory)
              << ", other allocations: " << format_size(current_allocated_size() - m_total_allocated_memory)
              << ")\n";
  }
  // before releasing the buffers make sure the command buffer has finished.
  // we need to release the lock temporarily as synchronizing may cause deadlock with completion handlers.
  m_mutex.unlock();
  auto stream = getDefaultMPSStream();
  dispatch_sync(stream->queue(), ^() {
    stream->synchronize(SyncType::COMMIT_AND_WAIT);
  });
  m_mutex.lock();
  // Free all cached blocks to the system allocator
  for (const auto& poolIt : m_pools) {
    BufferPool& pool = *poolIt.second;
    release_buffers(pool);
  }
  return true;
}
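// garbage_collect_cached_buffers() frees idle cached buffers until enough memory is reclaimed
// to get back under the low watermark limit, preferring buffers whose age (gc_count, i.e. the
// number of allocation rounds they sat unused in the pool) is above the running average.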
void MPSHeapAllocatorImpl::garbage_collect_cached_buffers(AllocParams& params) {
  // skip garbage collection if the memory pressure has already been relieved
  if (current_allocated_size() < m_low_watermark_limit) {
    return;
  }
  // attempt to collect garbage until we get below the low watermark limit
  const auto target_size = current_allocated_size() - m_low_watermark_limit;
  const BufferPool& pool = *params.pool;

  // calculate the total age of the free-able blocks. We'll use it later to get the average age threshold.
  double total_age = 0.0;
  unsigned int freeable_block_count = 0, freed_count = 0;
  size_t gc_reclaimed = 0;

  for (auto& b : pool.available_buffers) {
    if (b->retainCount() <= 1) {
      total_age += b->gc_count;
      ++freeable_block_count;
    }
  }
  if (freeable_block_count == 0) {
    return;
  }
  // repeat GC until we reclaim more than the target size.
  bool block_freed = true;
  while (gc_reclaimed < target_size && block_freed && freeable_block_count > 0) {
    // free blocks exceeding this age threshold first.
    double age_threshold = total_age / freeable_block_count;
    // stop iteration if we can no longer free a block.
    block_freed = false;
    // free blocks of > avg age. Stop garbage collection if we reach below the
    // low watermark limit since re-allocation or fragmentation could be costly.
    auto it = pool.available_buffers.begin();
    while (it != pool.available_buffers.end() && gc_reclaimed < target_size) {
      BufferBlock* buffer_block = *it++;
      if (buffer_block->gc_count >= age_threshold && buffer_block->retainCount() <= 1) {
        block_freed = true;
        gc_reclaimed += buffer_block->size;
        total_age -= buffer_block->gc_count;
        freeable_block_count--;
        freed_count++;
        release_buffer(buffer_block, !buffer_block->heap->is_split);
      }
    }
  }
  if (m_debug_verbosity & DebugVerbosity::RELEASES) {
    std::cerr << "Garbage collected " << freed_count << " buffers from large "
              << ((pool.usage & UsageFlags::SHARED) ? "shared" : "private")
              << " pool (total reclaimed: " << format_size(gc_reclaimed)
              << ", #buffers: " << pool.available_buffers.size() << ")\n";
  }
}

// public interface to MPSAllocator
id<MTLBuffer> MPSHeapAllocatorImpl::malloc(size_t size, uint32_t usage) {
  std::lock_guard lock(m_mutex);

  BufferBlock* buffer_block = alloc_buffer_block(size, usage);
  return buffer_block ? buffer_block->buffer : nullptr;
}

bool MPSHeapAllocatorImpl::isSharedBuffer(const void* ptr) {
  std::lock_guard lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  // it's OK for the buffer_block to not exist yet
  return buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED);
}

id<MTLBuffer> MPSHeapAllocatorImpl::allocScalarBufferWithValue(void* value, size_t size) {
  BufferBlock* buffer_block = nullptr;
  {
    std::lock_guard lock(m_mutex);

    buffer_block = alloc_buffer_block(size, UsageFlags::SCALAR);
    if (!buffer_block) {
      return nullptr;
    }
    if (!buffer_block->cpu_ptr) {
      buffer_block->cpu_ptr = [buffer_block->buffer contents];
    }
  }
  // buffer is out of the pool, so no mutex lock is needed
  memcpy(buffer_block->cpu_ptr, value, size);
  return buffer_block->buffer;
}

std::pair<void*, uint32_t> MPSHeapAllocatorImpl::getSharedBufferPtr(const void* ptr) {
  std::lock_guard lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  // return if buffer was not allocated on MPSAllocator or isn't a Shared buffer
  if (!buffer_block || !(buffer_block->heap->pool->usage & UsageFlags::SHARED)) {
    return {nullptr, 0};
  }
  if (!buffer_block->cpu_ptr) {
    buffer_block->cpu_ptr = [buffer_block->buffer contents];
  }
  return {buffer_block->cpu_ptr, buffer_block->retainCount()};
}

bool MPSHeapAllocatorImpl::recordEvents(c10::ArrayRef<const void*> buffers) {
  bool recordedEvent = false;
  std::lock_guard lock(m_mutex);

  for (const auto& buffer : buffers) {
    BufferBlock* buffer_block = get_allocated_buffer_block(buffer);
    // skip if buffer was not allocated on MPSAllocator or isn't a Shared buffer
    if (buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED)) {
      if (!buffer_block->event) {
        buffer_block->event = m_event_pool->acquireEvent(false, nullptr);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(buffer_block->event);
      }
      buffer_block->event->record(/*needsLock*/ false);
      recordedEvent = true;
    }
  }
  return recordedEvent;
}
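// waitForEvents() blocks the calling thread until the events recorded on the given shared
// buffers are signaled (i.e. the GPU work using them has completed), then frees any buffers
// whose release had been deferred while they were still in use.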
event if "shared" buffer was allocated on MPSAllocator and // or actually needs waiting (based on retainCount) if (buffer_block && (buffer_block->heap->pool->usage & UsageFlags::SHARED) && buffer_block->retainCount() > 1 && buffer_block->event) { buffer_blocks.push_back(buffer_block); } } } bool waitedForEvent = false; for (const auto& buffer_block : buffer_blocks) { // check for retain count again as the previous wait might have released the buffer if (buffer_block->retainCount() > 1) { bool waitedOnCPU = buffer_block->event->synchronize(); if (waitedOnCPU) { // after waiting, it's a good time to free some pending inactive buffers freeInactiveBuffers(); waitedForEvent |= buffer_block->retainCount() <= 1; } else { // even if one of the buffers weren't recorded beforehand, we return // without continuing with other buffers since retainCount > 1 waitedForEvent = false; break; } } } return waitedForEvent; } id_t MPSHeapAllocatorImpl::getBufferId(const void* ptr) { std::lock_guard lock(m_mutex); BufferBlock* buffer_block = get_allocated_buffer_block(ptr); return buffer_block ? buffer_block->buf_id : 0; } ssize_t MPSHeapAllocatorImpl::getUnalignedBufferSize(const void* ptr) { std::lock_guard lock(m_mutex); BufferBlock* buffer_block = get_allocated_buffer_block(ptr); if (buffer_block) { return (ssize_t)buffer_block->requested_size; } // -1 indicates the passed buffer pointer wasn't found return -1; } void MPSHeapAllocatorImpl::setBufferShape(const void* ptr, const IntArrayRef& shape) { std::lock_guard lock(m_mutex); BufferBlock* buffer_block = get_allocated_buffer_block(ptr); TORCH_INTERNAL_ASSERT(buffer_block, "failed to find the buffer ", ptr); // note that the IntArrayRef doesn't own the underlying data, and the backing // memory for shape data must persist as long as the buffer is in use. // So we need to copy to vector. 
void MPSHeapAllocatorImpl::setBufferShape(const void* ptr, const IntArrayRef& shape) {
  std::lock_guard lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  TORCH_INTERNAL_ASSERT(buffer_block, "failed to find the buffer ", ptr);
  // note that the IntArrayRef doesn't own the underlying data, and the backing
  // memory for the shape data must persist as long as the buffer is in use.
  // So we need to copy it to a vector.
  buffer_block->shape = shape.vec();
}

IntArrayRef MPSHeapAllocatorImpl::getBufferShape(const void* ptr) {
  std::lock_guard lock(m_mutex);

  BufferBlock* buffer_block = get_allocated_buffer_block(ptr);
  if (buffer_block && buffer_block->shape.size() > 0) {
    return IntArrayRef{buffer_block->shape};
  }
  return IntArrayRef();
}

void MPSHeapAllocatorImpl::free(void* ptr) {
  BufferBlock* buffer_block = nullptr;
  {
    std::lock_guard lock(m_mutex);

    buffer_block = get_allocated_buffer_block(ptr);
    TORCH_INTERNAL_ASSERT(buffer_block);
    const BufferPool& pool = *buffer_block->heap->pool;
    if (!(pool.usage & UsageFlags::SCALAR)) {
      free_buffer(buffer_block);
      return;
    }
  }
  // we sync the scalar pool manually with a completion handler at the time the buffer is
  // freed, i.e. when the MPSScalar instance goes out of scope
  m_stream->addCompletedHandler(^(id<MTLCommandBuffer>) {
    std::lock_guard lock(m_mutex);
    free_buffer(buffer_block);
  });
}

void MPSHeapAllocatorImpl::freeInactiveBuffers() {
  std::lock_guard lock(m_mutex);

  for (const auto& poolIt : m_pools) {
    BufferPool& pool = *poolIt.second;
    if (!pool.buffers_pending_free.empty()) {
      for (auto it = pool.buffers_pending_free.begin(), last = pool.buffers_pending_free.end(); it != last;) {
        BufferBlock* buffer_block = *it;
        if (buffer_block->retainCount() <= 1) {
          it = pool.buffers_pending_free.erase(it);
          free_buffer(buffer_block);
        } else {
          ++it;
        }
      }
    }
  }
}

void MPSHeapAllocatorImpl::emptyCache() {
  std::lock_guard lock(m_mutex);
  release_cached_buffers();
}

ssize_t MPSHeapAllocatorImpl::getLowWatermarkValue() {
  // check if the low watermark limit is disabled
  if (m_low_watermark_ratio == 0.0) {
    return std::numeric_limits<ssize_t>::max();
  }
  // current_allocated_size could exceed m_low_watermark_limit (e.g., when swapping to disk)
  return std::max<ssize_t>(0, (ssize_t)(m_low_watermark_limit - current_allocated_size()) / 1048576L);
}
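// For illustration (based on the thresholds below), a couple of expected results:
//   format_size(512)     -> "512 bytes"
//   format_size(2097152) -> "2.00 MB"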
inline std::string MPSHeapAllocatorImpl::format_size(uint64_t size) const {
  std::ostringstream os;
  os.precision(2);
  os << std::fixed;
  if (size <= 1024UL) {
    os << size << " bytes";
  } else if (size <= 1048576UL) {
    os << ((float)size / 1024.0) << " KB";
  } else if (size <= 1073741824UL) {
    os << ((float)size / 1048576.0) << " MB";
  } else {
    os << ((float)size / 1073741824.0) << " GB";
  }
  return os.str();
}

} // namespace HeapAllocator

// Use "at::mps::GetMPSAllocator()" to acquire a handle to MPS Allocator
namespace {
HeapAllocator::MPSHeapAllocatorImpl& _getAllocImpl() {
  static HeapAllocator::MPSHeapAllocatorImpl s_allocatorImpl;
  return s_allocatorImpl;
}
} // namespace

// MPS allocator struct to be registered with PyTorch
struct TORCH_API MPSAllocator final : public IMPSAllocator {
 public:
  explicit MPSAllocator(uint32_t Usage)
      : m_has_unified_memory(_getAllocImpl().Device().hasUnifiedMemory), m_usage(Usage) {
    if (_getAllocImpl().getDebugVerbosity()) {
      if (!(m_usage & HeapAllocator::UsageFlags::SHARED) || m_has_unified_memory) {
        std::cerr << "Initializing " << ((m_usage & HeapAllocator::UsageFlags::SHARED) ? "shared" : "private")
                  << " heap allocator on " << (m_has_unified_memory ? "unified" : "discrete")
                  << " device memory of size "
                  << _getAllocImpl().format_size(_getAllocImpl().Device().recommendedMaxWorkingSetSize) << "\n";
      }
    }
  }

  ~MPSAllocator() override {
    _getAllocImpl().emptyCache();
  }

  DeleterFnPtr raw_deleter() const override {
    return &Delete;
  }

  DataPtr allocate(const size_t nbytes) override {
    __block id<MTLBuffer> buf = nbytes > 0 ? _getAllocImpl().malloc(nbytes, m_usage) : nullptr;
    return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
  }

  // implementation of IMPSAllocator interface
  DataPtr allocScalarBufferWithValue(void* value, size_t size) const override {
    id<MTLBuffer> buf = _getAllocImpl().allocScalarBufferWithValue(value, size);
    return {buf, buf, &Delete, at::Device(at::DeviceType::MPS, 0)};
  }
  std::pair<void*, uint32_t> getSharedBufferPtr(const void* ptr) const override {
    return _getAllocImpl().getSharedBufferPtr(ptr);
  }
  bool isSharedBuffer(const void* ptr) const override {
    return _getAllocImpl().isSharedBuffer(ptr);
  }
  bool isSharedStorageSupported() const override {
    return m_has_unified_memory;
  }
  void emptyCache() const override {
    _getAllocImpl().emptyCache();
  }
  void freeInactiveBuffers() const override {
    _getAllocImpl().freeInactiveBuffers();
  }
  ssize_t getUnalignedBufferSize(const void* ptr) const override {
    return _getAllocImpl().getUnalignedBufferSize(ptr);
  }
  id_t getBufferId(const void* ptr) const override {
    return _getAllocImpl().getBufferId(ptr);
  }
  IntArrayRef getBufferShape(const void* ptr) const override {
    return _getAllocImpl().getBufferShape(ptr);
  }
  void setBufferShape(const void* ptr, const IntArrayRef& shape) const override {
    _getAllocImpl().setBufferShape(ptr, shape);
  }
  size_t getTotalAllocatedMemory() const override {
    return _getAllocImpl().getTotalAllocatedMemory();
  }
  size_t getCurrentAllocatedMemory() const override {
    return _getAllocImpl().getCurrentAllocatedMemory();
  }
  size_t getDriverAllocatedMemory() const override {
    return _getAllocImpl().getDriverAllocatedMemory();
  }
  size_t getRecommendedMaxMemory() const override {
    return _getAllocImpl().getRecommendedMaxMemory();
  }
  ssize_t getLowWatermarkValue() const override {
    return _getAllocImpl().getLowWatermarkValue();
  }
  size_t getLowWatermarkLimit() const override {
    return _getAllocImpl().getLowWatermarkLimit();
  }
  size_t getHighWatermarkLimit() const override {
    return _getAllocImpl().getHighWatermarkLimit();
  }
  void setLowWatermarkRatio(double ratio) const override {
    _getAllocImpl().setLowWatermarkRatio(ratio);
  }
  void setHighWatermarkRatio(double ratio) const override {
    _getAllocImpl().setHighWatermarkRatio(ratio);
  }
  bool recordEvents(c10::ArrayRef<const void*> buffers) const override {
    return _getAllocImpl().recordEvents(buffers);
  }
  bool waitForEvents(c10::ArrayRef<const void*> buffers) const override {
    return _getAllocImpl().waitForEvents(buffers);
  }
  std::string formatSize(size_t size) const override {
    return _getAllocImpl().format_size(size);
  }
  void copy_data(void* dest, const void* src, std::size_t count) const final {
    default_copy_data(dest, src, count);
  }

 private:
  bool m_has_unified_memory;
  uint32_t m_usage;

  static void Delete(void* ptr) {
    if (ptr) {
      _getAllocImpl().free(ptr);
    }
  }
};

namespace {
MPSAllocator& _getSharedAllocator() {
  static MPSAllocator s_mps_shared_alloc(HeapAllocator::UsageFlags::SHARED);
  return s_mps_shared_alloc;
}

MPSAllocator& _getPrivateAllocator() {
  static MPSAllocator s_mps_private_alloc(HeapAllocator::UsageFlags::PRIVATE);
  return s_mps_private_alloc;
}
} // anonymous namespace

IMPSAllocator* getIMPSAllocator(bool sharedAllocator) {
  if (!sharedAllocator) {
    return &_getPrivateAllocator();
  }
  auto& sa = _getSharedAllocator();
  if (sa.isSharedStorageSupported()) {
    return &sa;
  }
  return nullptr;
}
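// Illustrative usage sketch (hypothetical caller, not part of this file): the shared allocator
// is only returned on devices with unified memory, and allocate() comes from the c10::Allocator
// interface that IMPSAllocator derives from.
//
//   at::mps::IMPSAllocator* allocator = at::mps::getIMPSAllocator(/*sharedAllocator=*/true);
//   if (allocator) {
//     at::DataPtr ptr = allocator->allocate(1024); // a 1 KB MTLBuffer from the shared pools
//   }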
// torch.is_pinned() implementation
// Pinned memory will be helpful on Apple Silicon Macs with Unified memory as we
// will be able to use SharedStorageMode for MTLBuffer allocations. This will
// avoid extra copies on DataLoading operations.
bool isMPSPinnedPtr(const void* data) {
  return at::mps::_getSharedAllocator().isSharedBuffer(data);
}

} // namespace at::mps