//  Copyright © 2023 Apple Inc.

#include <ATen/mps/MPSEvent.h>

namespace at::mps {

// Creates an event bound to `stream`, backed by a new MTLSharedEvent obtained
// from the stream's device. `ID` is the identifier assigned by MPSEventPool.
MPSEvent::MPSEvent(id_t ID, MPSStream* stream, bool enable_timing)
    : m_id(ID), m_enable_timing(enable_timing), m_stream(stream), m_event([stream->device() newSharedEvent]) {}

// Releases the shared event and the lazily-created listener. This file uses
// manual retain/release, so both objects are released explicitly here.
MPSEvent::~MPSEvent() {
  if (m_event) {
    [m_event release];
    m_event = nil;
  }
  if (m_listener) {
    [m_listener release];
    m_listener = nil;
  }
}

// Encodes a signal of the next counter value on the stream's current command
// buffer. The caller must already be serialized with the stream; record()
// provides that by dispatching here onto the stream's queue.
// If `syncEvent` is true, the stream is committed (SyncType::COMMIT) afterwards.
void MPSEvent::recordLocked(bool syncEvent) {
  // active encoders must end before encoding or waiting
  m_stream->endKernelCoalescing();
  ++m_signalCounter;
  if (m_enable_timing) {
    // When the listener fires, store the completion timestamp (via getTime())
    // and wake any thread blocked in waitForCpuSync().
    notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
      m_completion_time = getTime();
      notifyCpuSync();
    });
  }
  id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
  [commandBuffer encodeSignalEvent:m_event value:m_signalCounter];
  if (syncEvent) {
    m_stream->synchronize(SyncType::COMMIT);
  }
}

// Encodes a GPU-side wait for this event's current signal value on the stream's
// command buffer.
// Returns false without encoding anything if the event has already reached that
// value. This includes the case where the event was never recorded, because the
// counter is then 0.
bool MPSEvent::waitLocked(bool syncEvent) {
  if (m_event.signaledValue >= m_signalCounter) {
    return false;
  }
  // active encoders must end before encoding or waiting
  m_stream->endKernelCoalescing();
  id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
  [commandBuffer encodeWaitForEvent:m_event value:m_signalCounter];
  if (syncEvent) {
    m_stream->synchronize(SyncType::COMMIT);
  }
  return true;
}

// Schedules `block` to run once the event reaches its current signal value.
// The listener is created on first use.
// Returns false (nothing scheduled) if the event is unrecorded or already
// signaled.
bool MPSEvent::notifyLocked(MTLSharedEventNotificationBlock block) {
  if (m_event.signaledValue >= m_signalCounter) {
    return false;
  }
  if (!m_listener) {
    m_listener = [[MTLSharedEventListener alloc] init];
  }
  [m_event notifyListener:m_listener atValue:m_signalCounter block:block];
  return true;
}

// Records the event. When `needsLock` is true, the work is dispatched
// synchronously onto the stream's queue.
void MPSEvent::record(bool needsLock, bool syncEvent) {
  if (!needsLock) {
    recordLocked(syncEvent);
    return;
  }
  dispatch_sync(m_stream->queue(), ^() {
    @autoreleasepool {
      recordLocked(syncEvent);
    }
  });
}

// Encodes a wait on this event (see waitLocked). Returns whether a wait was
// actually encoded.
bool MPSEvent::wait(bool needsLock, bool syncEvent) {
  if (!needsLock) {
    return waitLocked(syncEvent);
  }
  __block bool waited = false;
  dispatch_sync(m_stream->queue(), ^() {
    @autoreleasepool {
      waited = waitLocked(syncEvent);
    }
  });
  return waited;
}

// Schedules a listener callback (see notifyLocked). Returns whether the
// callback was scheduled.
bool MPSEvent::notify(bool needsLock, MTLSharedEventNotificationBlock block) {
  if (!needsLock) {
    return notifyLocked(block);
  }
  __block bool scheduledNotify = false;
  dispatch_sync(m_stream->queue(), ^() {
    @autoreleasepool {
      scheduledNotify = notifyLocked(block);
    }
  });
  return scheduledNotify;
}

// Sets the CPU-sync flag and wakes one waiter in waitForCpuSync().
void MPSEvent::notifyCpuSync() {
  std::lock_guard<std::mutex> lock(m_cpu_sync_mutex);
  m_cpu_sync_completed = true;
  m_cpu_sync_cv.notify_one();
}

// Blocks until notifyCpuSync() has set the flag, then consumes it (resets it
// to false).
void MPSEvent::waitForCpuSync() {
  std::unique_lock<std::mutex> lock(m_cpu_sync_mutex);
  m_cpu_sync_cv.wait(lock, [&] { return m_cpu_sync_completed; });
  m_cpu_sync_completed = false;
}

// Blocks the calling CPU thread until the event reaches its current signal
// value, and records the completion timestamp when it does.
// Returns false immediately if there was nothing to wait for.
bool MPSEvent::synchronize() {
  bool scheduledNotify = notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
    m_completion_time = getTime();
    notifyCpuSync();
  });
  if (scheduledNotify) {
    waitForCpuSync();
    return true;
  }
  return false;
}

// Returns true only if the event was recorded at least once and has been
// signaled up to the latest recorded value.
bool MPSEvent::query() const {
  return m_signalCounter && (m_event.signaledValue >= m_signalCounter);
}

// Prepares a pooled event for reuse.
// - If the stream changes, the signal counter and the shared event's
//   signaledValue restart at 0.
// - Timing state and the CPU-sync flag are always cleared.
void MPSEvent::reset(MPSStream* stream, bool enable_timing) {
  if (stream != m_stream) {
    m_signalCounter = 0;
    m_event.signaledValue = 0;
    m_stream = stream;
  }
  // reset record time
  m_completion_time = 0;
  m_enable_timing = enable_timing;
  m_cpu_sync_completed = false;
}

//-----------------------------------------------------------------
//  MPSEventPool
//-----------------------------------------------------------------

MPSEventPool::MPSEventPool(MPSStream* default_stream) : m_default_stream(default_stream) {
  // Default deleter: instead of destroying a released event, push it back into
  // the pool so it can be reused.
  m_default_deleter = [this](MPSEvent* event) {
    std::lock_guard lock(m_mutex);
    m_pool.push(std::unique_ptr<MPSEvent>(event));
  };
}

MPSEventPool::~MPSEventPool() {
  emptyCache();
}

// Returns a pooled event if one is available (reset for `stream` and
// `enable_timing`); otherwise creates a new one. A null `stream` selects the
// pool's default stream. The returned pointer gives the event back to the pool
// when released.
MPSEventPtr MPSEventPool::acquireEvent(bool enable_timing, MPSStream* stream) {
  if (!stream) {
    stream = m_default_stream;
  }
  id_t new_event_id = 0;
  {
    std::lock_guard lock(m_mutex);
    if (!m_pool.empty()) {
      auto event = m_pool.top().release();
      m_pool.pop();
      event->reset(stream, enable_timing);
      return MPSEventPtr(event, m_default_deleter);
    }
    // m_event_counter is shared by all callers; reserve the new ID while still
    // holding the lock so concurrent acquisitions can't race or produce
    // duplicate IDs.
    new_event_id = ++m_event_counter;
  }
  auto new_event = std::make_unique<MPSEvent>(new_event_id, stream, enable_timing);
  return MPSEventPtr(new_event.release(), m_default_deleter);
}

// Destroys all events currently held in the pool (not events that are in use).
void MPSEventPool::emptyCache() {
  std::lock_guard lock(m_mutex);
  while (!m_pool.empty()) {
    m_pool.pop();
  }
}

// Acquires an event on the default stream and keeps it alive in
// m_in_use_events. Returns its ID, which the ID-based methods below accept.
id_t MPSEventPool::acquireEvent(bool enable_timing) {
  // NOTE: the overload called below locks m_mutex again, so this relies on
  // m_mutex being re-entrant.
  std::lock_guard lock(m_mutex);
  MPSEventPtr event = acquireEvent(enable_timing, nullptr);
  TORCH_INTERNAL_ASSERT(event);
  id_t event_id = event->getID();
  m_in_use_events.emplace(event_id, std::move(event));
  return event_id;
}

// Drops the in-use reference for `event_id`; its deleter returns the event to
// the pool. Throws (TORCH_CHECK) if the ID is not in use.
void MPSEventPool::releaseEvent(id_t event_id) {
  std::lock_guard lock(m_mutex);
  auto it = m_in_use_events.find(event_id);
  TORCH_CHECK(it != m_in_use_events.end(), "Invalid Event ID: ", event_id);
  // returns the event back to the MPSEventPool
  m_in_use_events.erase(it);
}

void MPSEventPool::recordEvent(id_t event_id, bool syncEvent) {
  MPSEvent* event = getInUseEvent(event_id);
  event->record(/*needsLock*/ true, syncEvent);
}

void MPSEventPool::waitForEvent(id_t event_id, bool syncEvent) {
  MPSEvent* event = getInUseEvent(event_id);
  event->wait(/*needsLock*/ true, syncEvent);
}

void MPSEventPool::synchronizeEvent(id_t event_id) {
  MPSEvent* event = getInUseEvent(event_id);
  event->synchronize();
}

bool MPSEventPool::queryEvent(id_t event_id) {
  MPSEvent* event = getInUseEvent(event_id);
  return event->query();
}

// Returns the difference between the two events' completion timestamps, scaled
// by 1e-6. This is milliseconds if getTime() reports nanoseconds (TODO confirm).
// Throws if either timestamp is 0 (timing not enabled), or if end is not after
// start.
// NOTE(review): if the end event was never recorded with enable_timing=true,
// nothing calls its notifyCpuSync(), so waitForCpuSync() below may block
// forever instead of reaching the TORCH_CHECK.
// NOTE(review): only the end event's callback is waited on. The start event's
// completion time is read without synchronizing with its listener callback.
double MPSEventPool::elapsedTime(id_t start_event_id, id_t end_event_id) {
  // first make sure notifyListeners are called to capture events' completion times
  dispatch_sync(m_default_stream->queue(), ^() {
    @autoreleasepool {
      m_default_stream->synchronize(SyncType::COMMIT_AND_WAIT);
    }
  });
  std::lock_guard lock(m_mutex);
  MPSEvent* start_event = getInUseEvent(start_event_id, false);
  MPSEvent* end_event = getInUseEvent(end_event_id, false);
  // the notify is called on a separate thread, so this waits for that
  end_event->waitForCpuSync();
  const uint64_t start_time = start_event->getCompletionTime();
  const uint64_t end_time = end_event->getCompletionTime();

  TORCH_CHECK(start_time > 0 && end_time > 0, "Events were not created with argument 'enable_timing=True'");
  TORCH_CHECK(
      end_time > start_time, "End event ", end_event_id, " was not recorded after start event ", start_event_id);
  return double(end_time - start_time) * 1e-6;
}

// Looks up an in-use event by ID and throws (TORCH_CHECK) if it is unknown.
// Pass `locked = false` only when the caller already holds m_mutex.
// The unique_lock releases the mutex even when TORCH_CHECK throws; the previous
// manual lock()/unlock() left it held on that path.
MPSEvent* MPSEventPool::getInUseEvent(id_t event_id, bool locked) {
  std::unique_lock lock(m_mutex, std::defer_lock);
  if (locked) {
    lock.lock();
  }
  auto it = m_in_use_events.find(event_id);
  TORCH_CHECK(it != m_in_use_events.end(), "Invalid Event ID: ", event_id);
  return it->second.get();
}

// Returns the process-wide event pool, created on first use for the default
// MPS stream.
std::shared_ptr<MPSEventPool> getMPSEventPool() {
  static std::shared_ptr<MPSEventPool> event_pool = std::make_shared<MPSEventPool>(getDefaultMPSStream());
  return event_pool;
}

} // namespace at::mps