#include <torch/csrc/autograd/profiler_legacy.h>

#include <torch/csrc/autograd/function.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/code_template.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/record_function.h>
#include <c10/core/Allocator.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#include <fstream>
#include <iostream>
#include <list>
#include <mutex>
#include <sstream>
#include <string>
#include <vector>

namespace torch::autograd::profiler {

// We decompose the profiler logic into the following components:
//
// ThreadLocalDebugInfo:
//
// ThreadLocalDebugInfo is a thread local mapping from slots into
// the debug information structs.
// ThreadLocalDebugInfo is automatically propagated across thread
// boundaries, including the cases of:
//  - launching async jobs with at::launch
//  - executing JIT continuations
//  - moving from the forward threads into autograd (backward) threads
//
// Entries in ThreadLocalDebugInfo are managed by DebugInfoGuard,
// which can be used to add or overwrite an entry in the thread local
// mapping. A corresponding entry is removed when the guard is destroyed,
// potentially revealing the previously set value for the same slot.
//
// For async tasks, slots previously set in the main thread before
// launching the async task are shared and visible in the async task.
//
// On the other hand, any adding or overwriting of the mapping by the
// async task is not visible to the main thread, and any modification
// (including removal of the entries) in the main thread is not visible
// to the async task if it happens after launching the task.
//
// We use ThreadLocalDebugInfo (slot PROFILER_STATE) to store the profiler
// config, as well as a list of events that happen during profiling.
// An instance of ThreadLocalDebugInfo is created each time we enter the
// profiler (i.e. enter the profiling context manager / call enableProfiler)
// and uniquely identifies a profiling run.
//
// We automatically propagate ThreadLocalDebugInfo into async tasks,
// as well as across JIT continuations and autograd threads, so all
// the operations that happen between profiling start and end
// (not necessarily within the same thread) are recorded.
// The exception is when the profiling slot is overwritten, as in the case
// of nested profiling ranges; in that case events for the subrange are
// handled by the nested profiler.
//
// When we exit a profiling range (either by exiting the profiling context
// manager or by calling disableProfiler), we remove the previously set
// profiling entry for the given thread local mapping, and consolidate
// events into the profiling result.
//
//
// ThreadLocalState:
//
// ThreadLocalState takes a 'snapshot' of thread local variables
// using provided getters. It is used together with ThreadLocalStateGuard
// to transfer the snapshot across thread boundaries and set the thread
// local values as in the parent task.
//
// Profiler uses ThreadLocalState to propagate the profiler's thread local
// state. ThreadLocalState also automatically propagates profiler callbacks.
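//
// As a rough sketch of the DebugInfoGuard pattern described above (purely
// illustrative, not code from this file; `MyDebugInfo` is a hypothetical
// c10::DebugInfoBase subclass and TEST_INFO is just a spare slot):
//
//   {
//     auto info = std::make_shared<MyDebugInfo>();
//     c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, info);
//     // Within this scope, and inside async tasks launched from it
//     // (e.g. via at::launch),
//     // c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO)
//     // returns info.get().
//   }
//   // Guard destroyed: the previously set value for the slot, if any,
//   // becomes visible again.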
//
//
// at::RecordFunction and observers
//
// Profiler uses the observers mechanism to add a pair of thread local
// callbacks that are executed on a number of predetermined ranges,
// including:
//  - c10/ATen ops
//  - TorchScript functions/methods
//  - user defined named ranges (see the `record_function` Python context
//    manager)
//
// Profiler sets up a pair of callbacks that record profiling events and
// save them into the thread local profiler struct (ThreadLocalDebugInfo,
// PROFILER_STATE slot).
//
//
// Thus, the overall logic is:
//
// enableProfiler:
//  - checks that the profiler is not already enabled (otherwise throws)
//  - pushes a new ThreadLocalDebugInfo (slot PROFILER_STATE) as the profiler
//    config for the current thread
//  - pushes profiling callbacks for the current thread
//
// disableProfiler:
//  - pops the PROFILER_STATE slot from the current ThreadLocalDebugInfo and
//    consolidates events
//  - removes profiling callbacks
//
// ThreadLocalState:
//  - propagates ThreadLocalDebugInfo across threads
//  - propagates profiler callbacks across threads
//
// Profiler callbacks:
//  - get the current profiling state (PROFILER_STATE slot in
//    ThreadLocalDebugInfo)
//  - save profiling events into the profiling state
//

namespace {
using torch::profiler::impl::ActiveProfilerType;
using torch::profiler::impl::ProfilerStateBase;

struct ProfilerLegacyThreadLocalState : public ProfilerStateBase {
  explicit ProfilerLegacyThreadLocalState(
      const torch::profiler::impl::ProfilerConfig& config)
      : ProfilerStateBase(config), remoteProfiledEvents_{std::nullopt} {}
  ~ProfilerLegacyThreadLocalState() override = default;

  static ProfilerLegacyThreadLocalState* getTLS() {
    auto tls = ProfilerStateBase::get(/*global=*/false);
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        tls == nullptr || tls->profilerType() == ActiveProfilerType::LEGACY);
    return static_cast<ProfilerLegacyThreadLocalState*>(tls);
  }

  thread_event_lists consolidate();

  void mark(std::string name, bool include_cuda = true);

  void setOrAddRemoteProfiledEvents(
      std::vector<LegacyEvent>&& remoteProfiledEvents);

  void pushRange(
      const at::RecordFunction& fn,
      const bool record_cuda,
      std::vector<std::vector<int64_t>>&& shapes = {});

  void popRange(const at::RecordFunction& fn, const bool record_cuda);

  void reportMemoryUsage(
      void* /* unused */,
      int64_t alloc_size,
      size_t /* total_allocated, unused for legacy */,
      size_t /* total_reserved, unused for legacy */,
      c10::Device device) override;

  ActiveProfilerType profilerType() override {
    return ActiveProfilerType::LEGACY;
  }

  void leakHandle() {
    handle_ = 0;
  }

 protected:
  RangeEventList& getEventList(
      std::optional<uint64_t> thread_id = std::nullopt);

  std::mutex state_mutex_;
  std::unordered_map<uint64_t, std::shared_ptr<RangeEventList>>
      event_lists_map_;

  std::optional<std::vector<std::vector<LegacyEvent>>> remoteProfiledEvents_;
};

thread_event_lists ProfilerLegacyThreadLocalState::consolidate() {
  std::lock_guard<std::mutex> g(state_mutex_);
  thread_event_lists result;
  for (auto& kv : event_lists_map_) {
    auto& list = kv.second;
    result.emplace_back(list->consolidate());
  }

  // Consolidate remote events if applicable as well.
  if (remoteProfiledEvents_) {
    result.insert(
        result.end(),
        std::make_move_iterator(remoteProfiledEvents_->begin()),
        std::make_move_iterator(remoteProfiledEvents_->end()));
  }
  return result;
}

void ProfilerLegacyThreadLocalState::mark(
    std::string name,
    bool include_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->mark(name.c_str());
  } else {
    LegacyEvent evt(
        EventKind::Mark,
        at::StringView(std::move(name)),
        at::RecordFunction::currentThreadId(),
        include_cuda &&
            config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList().record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::setOrAddRemoteProfiledEvents(
    std::vector<LegacyEvent>&& remoteProfiledEvents) {
  // Lock to serialize access from multiple callback threads.
  std::lock_guard<std::mutex> guard(state_mutex_);
  if (remoteProfiledEvents_) {
    (*remoteProfiledEvents_).emplace_back(remoteProfiledEvents);
  } else {
    remoteProfiledEvents_ = {std::move(remoteProfiledEvents)};
  }
}

void ProfilerLegacyThreadLocalState::pushRange(
    const at::RecordFunction& fn,
    const bool record_cuda,
    std::vector<std::vector<int64_t>>&& shapes) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePush(
        torch::profiler::impl::getNvtxStr(fn.name(), fn.seqNr(), shapes)
            .c_str());
  } else {
    LegacyEvent evt(
        EventKind::PushRange,
        at::StringView(std::string(fn.name())),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle(),
        std::move(shapes),
        at::RecordFunction::getDefaultNodeId(),
        fn.isAsync());
    evt.setSequenceNr(fn.seqNr());
    evt.setFwdThreadId(fn.forwardThreadId());
    evt.setScope((uint8_t)fn.scope());
    if (config_.with_flops) {
      evt.setExtraArgs(torch::profiler::impl::saveExtraArgs(fn));
      evt.setFlops(torch::profiler::impl::computeFlops(
          std::string(fn.name()), evt.extraArgs()));
    }

// TODO: will unify the two macros BUILD_LITE_INTERPRETER and C10_MOBILE soon.
#if !defined BUILD_LITE_INTERPRETER && !defined C10_MOBILE
    // A backward node's source range corresponds to the forward node.
    // TODO: consider using C++ stack trace
    if (config_.with_stack &&
        fn.scope() != at::RecordScope::BACKWARD_FUNCTION) {
      auto cs =
          torch::profiler::impl::prepareCallstack(jit::currentCallstack());
      if (cs.empty()) {
        cs = torch::profiler::impl::prepareCallstack(
            jit::tracer::pythonCallstack());
      }
      evt.setStack(callstackStr(cs));
    }
#endif
    getEventList().record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::popRange(
    const at::RecordFunction& fn,
    const bool record_cuda) {
  if (config_.disabled()) {
    return;
  }
  if (config_.state == torch::profiler::impl::ProfilerState::NVTX) {
    torch::profiler::impl::cudaStubs()->rangePop();
  } else {
    // In some cases RecordFunction (and popRange) may be called on a
    // different thread than pushRange. As a convention, we put the async
    // pop on the original thread and save the current thread id in the
    // pop event.
    LegacyEvent evt(
        EventKind::PopRange,
        at::StringView(""),
        at::RecordFunction::currentThreadId(),
        record_cuda,
        fn.handle());
    evt.setNodeId(at::RecordFunction::getDefaultNodeId());
    getEventList(fn.threadId()).record(std::move(evt));
  }
}

void ProfilerLegacyThreadLocalState::reportMemoryUsage(
    void* /* unused */,
    int64_t alloc_size,
    size_t /* total_allocated, unused for legacy */,
    size_t /* total_reserved, unused for legacy */,
    c10::Device device) {
  if (config_.profile_memory && !config_.disabled()) {
    uint64_t thread_id = at::RecordFunction::currentThreadId();
    LegacyEvent evt(
        EventKind::MemoryAlloc,
        at::StringView(""),
        thread_id,
        config_.state == torch::profiler::impl::ProfilerState::CUDA);
    evt.updateMemoryStats(alloc_size, device);
    getEventList(thread_id).record(std::move(evt));
  }
}

RangeEventList& ProfilerLegacyThreadLocalState::getEventList(
    std::optional<uint64_t> thread_id) {
  if (!thread_id.has_value()) {
    thread_id = at::RecordFunction::currentThreadId();
  }
  RangeEventList* list_ptr = nullptr;
  std::lock_guard<std::mutex> guard(state_mutex_);
  auto it = event_lists_map_.find(thread_id.value());
  if (it != event_lists_map_.end()) {
    list_ptr = it->second.get();
  } else {
    auto event_list = std::make_shared<RangeEventList>();
    event_lists_map_[thread_id.value()] = event_list;
    list_ptr = event_list.get();
  }
  return *list_ptr;
}

enum EventIValueIdx {
  KIND = 0,
  NAME,
  THREAD_ID,
  HANDLE,
  NODE_ID,
  CPU_MEM_USAGE,
  CPU_NS,
  CUDA_RECORDED,
  CUDA_MEM_USAGE,
  CUDA_DEVICE,
  CUDA_US,
  SHAPES,
  NUM_EVENT_IVALUE_IDX // must be last in list
};

const std::unordered_set<std::string> disable_cuda_profiling = {
    "aten::view",
    "aten::t",
    "aten::transpose",
    "aten::stride",
    "aten::empty",
    "aten::empty_like",
    "aten::empty_strided",
    "aten::as_strided",
    "aten::expand",
    "aten::resize_",
    "aten::squeeze",
    "aten::unsqueeze",
    "aten::slice",
    "aten::_unsafe_view",
    "aten::size"};

void pushProfilingCallbacksLegacy() {
  auto registration_state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_INTERNAL_ASSERT(registration_state_ptr, "Expected profiler state set");
  auto handle = at::addThreadLocalCallback(
      at::RecordFunctionCallback(
          [](const at::RecordFunction& fn)
              -> std::unique_ptr<at::ObserverContext> {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return nullptr;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }

            if (state_ptr->config().report_input_shapes) {
              auto sizes = torch::profiler::impl::inputSizes(fn);
              state_ptr->pushRange(fn, record_cuda, std::move(sizes));
            } else {
              state_ptr->pushRange(fn, record_cuda);
            }

            return nullptr;
          },
          [](const at::RecordFunction& fn, at::ObserverContext*) {
            auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
            if (!state_ptr || state_ptr->config().disabled()) {
              return;
            }
            bool record_cuda = state_ptr->config().state ==
                torch::profiler::impl::ProfilerState::CUDA;
            if (record_cuda &&
                disable_cuda_profiling.find(fn.name()) !=
                    disable_cuda_profiling.end()) {
              record_cuda = false;
            }
            state_ptr->popRange(fn, record_cuda);
          })
          .needsInputs(registration_state_ptr->config().report_input_shapes)
          .needsIds(true));
  registration_state_ptr->setCallbackHandle(handle);
}

} // namespace

void enableProfilerLegacy(
    const torch::profiler::impl::ProfilerConfig& new_config) {
  TORCH_CHECK(
      new_config.state != torch::profiler::impl::ProfilerState::NVTX ||
          torch::profiler::impl::cudaStubs()->enabled(),
      "Can't use NVTX profiler - PyTorch was compiled without CUDA");

  TORCH_CHECK(
      new_config.state != torch::profiler::impl::ProfilerState::KINETO);

  auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(!state_ptr, "Profiler is already enabled on this thread");
  auto state = std::make_shared<ProfilerLegacyThreadLocalState>(new_config);
  c10::ThreadLocalDebugInfo::_push(c10::DebugInfoKind::PROFILER_STATE, state);

  pushProfilingCallbacksLegacy();

  state->mark("__start_profile", false);
}

thread_event_lists disableProfilerLegacy(
    std::optional<ProfilerDisableOptions> profilerDisableOptions) {
  auto cleanupTLSState =
      profilerDisableOptions ? profilerDisableOptions->cleanupTLSState : true;
  auto consolidate =
      profilerDisableOptions ? profilerDisableOptions->consolidate : true;
  // All the DebugInfoBase objects are scope based and supposed to use
  // DebugInfoGuard.
  std::shared_ptr<c10::DebugInfoBase> state;
  if (cleanupTLSState) {
    state = c10::ThreadLocalDebugInfo::_pop(c10::DebugInfoKind::PROFILER_STATE);
  } else {
    state =
        c10::ThreadLocalDebugInfo::_peek(c10::DebugInfoKind::PROFILER_STATE);
  }

  auto state_ptr = static_cast<ProfilerLegacyThreadLocalState*>(state.get());
  TORCH_CHECK(
      state_ptr && !state_ptr->config().disabled(),
      "Can't disable profiler when it's not running");

  cleanupTLSState ? state_ptr->removeCallback() : state_ptr->leakHandle();
  if (!consolidate ||
      state_ptr->config().state ==
          torch::profiler::impl::ProfilerState::NVTX) {
    return thread_event_lists();
  }

  state_ptr->mark("__stop_profile", false);
  // Note that this will erase the underlying events.
  return state_ptr->consolidate();
}

void addEventList(std::vector<LegacyEvent>&& profiledEvents) {
  auto state_ptr = ProfilerLegacyThreadLocalState::getTLS();
  TORCH_CHECK(state_ptr, "Profiler must be enabled.");
  state_ptr->setOrAddRemoteProfiledEvents(std::move(profiledEvents));
}

void LegacyEvent::record(bool record_cuda) {
  if (record_cuda) {
    torch::profiler::impl::cudaStubs()->record(
        &device_, &cuda_event, &cpu_ns_);
    return;
  }
  cpu_ns_ = c10::getTime();
}

/* static */ LegacyEvent LegacyEvent::fromIValue(
    const at::IValue& eventIValue) {
  TORCH_INTERNAL_ASSERT(
      eventIValue.isList(),
      "Expected IValue to contain type c10::impl::GenericList");
  auto ivalues = eventIValue.toList();
  TORCH_INTERNAL_ASSERT(
      ivalues.size() >= NUM_EVENT_IVALUE_IDX,
      "Expected at least ",
      NUM_EVENT_IVALUE_IDX,
      " elements to reconstruct LegacyEvent.");

  // Reconstruct input shapes from ivalues.
  const auto& shapeListIValue = ivalues.get(EventIValueIdx::SHAPES);
  TORCH_INTERNAL_ASSERT(
      shapeListIValue.isList(),
      "Expected profiler shapes IValue to contain type c10::impl::GenericList.");

  auto shapeList = shapeListIValue.toList();
  std::vector<std::vector<int64_t>> shapes;
  shapes.reserve(shapeList.size());
  for (const auto i : c10::irange(shapeList.size())) {
    std::vector<int64_t> s;
    const auto& shapeIValue = shapeList.get(i);
    TORCH_INTERNAL_ASSERT(
        shapeIValue.isList(),
        "Expected each profiler shape element to contain shapes of type c10::impl::GenericList.");
    auto curShapesList = shapeIValue.toList();
    s.reserve(curShapesList.size());
    for (const auto j : c10::irange(curShapesList.size())) {
      s.emplace_back(curShapesList.get(j).toInt());
    }
    shapes.emplace_back(s);
  }

  LegacyEvent evt(
      static_cast<EventKind>(
          ivalues.get(EventIValueIdx::KIND).toInt()), // EventKind
      at::StringView(ivalues.get(EventIValueIdx::NAME).toStringRef()), // name
      ivalues.get(EventIValueIdx::THREAD_ID).toInt(), // thread_id
      static_cast<at::RecordFunctionHandle>(
          ivalues.get(EventIValueIdx::HANDLE).toDouble()), // handle
      std::move(shapes), // input shapes
      ivalues.get(EventIValueIdx::NODE_ID).toInt(), // node id
      true, // is remote
      ivalues.get(EventIValueIdx::CPU_MEM_USAGE).toInt(), // cpu_mem_usage
      ivalues.get(EventIValueIdx::CPU_NS).toInt(), // cpu_ns
      ivalues.get(EventIValueIdx::CUDA_RECORDED).toBool(), // was cuda recorded
      ivalues.get(EventIValueIdx::CUDA_MEM_USAGE).toInt(), // cuda memory usage
      c10::DeviceIndex(
          ivalues.get(EventIValueIdx::CUDA_DEVICE).toInt()), // device
      static_cast<int64_t>(
          ivalues.get(EventIValueIdx::CUDA_US).toInt()) // cuda_us
  );
  return evt;
}

at::IValue LegacyEvent::toIValue() const {
  c10::impl::GenericList eventIValueList(at::AnyType::get());
  eventIValueList.reserve(NUM_EVENT_IVALUE_IDX);
  eventIValueList.emplace_back(static_cast<int64_t>(kind_));
  eventIValueList.emplace_back(std::string(name_.str()));
  eventIValueList.emplace_back(static_cast<int64_t>(thread_id_));
  eventIValueList.emplace_back(static_cast<double>(handle_));
  eventIValueList.emplace_back(node_id_);
  eventIValueList.emplace_back(cpu_memory_usage_);
  eventIValueList.emplace_back(cpu_ns_);
  // CUDA event information
  bool cuda_profiling_enabled = hasCuda();
  eventIValueList.emplace_back(cuda_profiling_enabled);
  eventIValueList.emplace_back(static_cast<int64_t>(cuda_memory_usage_));
  eventIValueList.emplace_back(device_);
  eventIValueList.emplace_back(cuda_us_);
  // Shapes
  c10::impl::GenericList shapesList =
      c10::impl::GenericList(at::ListType::create(at::IntType::get()));
  shapesList.reserve(shapes_.size());
  for (const auto& shape : shapes_) {
    c10::impl::GenericList s = c10::impl::GenericList(at::IntType::get());
    s.reserve(shape.size());
    for (const auto& k : shape) {
      s.emplace_back(k);
    }
    shapesList.emplace_back(s);
  }
  eventIValueList.emplace_back(shapesList);
  return at::IValue(eventIValueList);
}

double LegacyEvent::cudaElapsedUs(const LegacyEvent& e) const {
  TORCH_CHECK(e.hasCuda() && hasCuda(), "Events were not recorded for CUDA");
  TORCH_CHECK(
      e.device() == device(),
      c10::str(
          "Events are not on the same device: ",
          e.device(),
          " vs ",
          device()));
  if (isRemote() && e.isRemote()) {
    // Validate that cuda_us_ has been set properly.
    TORCH_INTERNAL_ASSERT(cuda_us_ >= 0 && e.cuda_us_ >= 0);
    return static_cast<double>(e.cuda_us_ - cuda_us_);
  }
  return torch::profiler::impl::cudaStubs()->elapsed(
      &cuda_event, &e.cuda_event);
}

static const at::jit::CodeTemplate event_template(R"(
{
  "name": "${name}",
  "ph": "X",
  "ts": ${ts},
  "dur": ${dur},
  "tid": ${tid},
  "pid": "CPU Functions",
  "args": {}
})");

void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events) {
  TORCH_CHECK(out, "Could not open file");
  LegacyEvent* profiler_start = nullptr;
  for (LegacyEvent* e : events) {
    if (0 == strcmp(e->name(), "__start_profile")) {
      profiler_start = e;
      break;
    }
  }
  TORCH_CHECK(profiler_start, "Could not find __start_profile mark");

  struct PairHash {
    size_t operator()(
        std::pair<at::RecordFunctionHandle, int64_t> p) const noexcept {
      return std::hash<at::RecordFunctionHandle>()(p.first) ^
          std::hash<int64_t>()(p.second);
    }
  };
  std::unordered_map<
      std::pair<at::RecordFunctionHandle, int64_t>,
      LegacyEvent*,
      PairHash>
      events_map;
  out << "[\n";
  bool first = true;
  for (LegacyEvent* evt : events) {
    if (evt->kindStr() == "push") {
      events_map[std::make_pair(evt->handle(), evt->nodeId())] = evt;
    } else if (evt->kindStr() == "pop") {
      if (!first) {
        out << ",\n";
      }
      first = false;
      auto it = events_map.find(std::make_pair(evt->handle(), evt->nodeId()));
      TORCH_CHECK(it != events_map.end(), "Unmatched pop event");
      LegacyEvent* evt_start = it->second;
      events_map.erase(it);

      at::jit::TemplateEnv env;
      env.s("name", evt_start->name());
      env.d("ts", profiler_start->cpuElapsedUs(*evt_start));
      env.d("dur", evt_start->cpuElapsedUs(*evt));
      env.d("tid", evt_start->threadId());
      out << event_template.format(env);
    }
  }
  out << "]\n";
}

RecordProfile::RecordProfile(std::ostream& out) : out_(out) {
  init();
}

RecordProfile::RecordProfile(const std::string& filename)
    : file_(std::make_unique<std::ofstream>(filename)), out_(*file_) {
  init();
}

void RecordProfile::init() {
  enableProfilerLegacy(torch::profiler::impl::ProfilerConfig(
      torch::profiler::impl::ProfilerState::CPU));
}

RecordProfile::~RecordProfile() {
  try {
    thread_event_lists event_lists = disableProfilerLegacy();
    std::vector<LegacyEvent*> events;
    for (auto& l : event_lists) {
      for (auto& e : l) {
        events.push_back(&e);
      }
    }
    processEvents(events);
  } catch (const std::exception& e) {
    LOG(ERROR) << e.what() << '\n';
  } catch (...) {
    LOG(ERROR) << "Unknown error" << '\n';
  }
}

void RecordProfile::processEvents(const std::vector<LegacyEvent*>& events) {
  writeProfilerEventsToStream(out_, events);
}

} // namespace torch::autograd::profiler
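
// Illustrative usage sketch (not part of this translation unit): RecordProfile
// is the RAII entry point that ties the pieces above together; it calls
// enableProfilerLegacy() in its constructor and, on destruction, calls
// disableProfilerLegacy() and writes the consolidated events as chrome-trace
// style JSON. The file name and `run_model()` below are hypothetical.
//
//   {
//     torch::autograd::profiler::RecordProfile guard("trace.json");
//     run_model(); // CPU op ranges recorded here end up in trace.json
//   }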