#pragma once

#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

#include <torch/csrc/Export.h>
#include <torch/csrc/profiler/api.h>
#include <torch/csrc/profiler/stubs/base.h>
#include <torch/csrc/profiler/util.h>

namespace torch::autograd::profiler {

enum class C10_API_ENUM EventKind : uint16_t {
  Mark,
  PushRange,
  PopRange,
  MemoryAlloc,
};

// To be deprecated, once we switch to Kineto profiling
struct TORCH_API LegacyEvent {
  LegacyEvent(
      EventKind kind,
      at::StringView name,
      uint16_t thread_id,
      bool record_cuda,
      at::RecordFunctionHandle handle = 0,
      std::vector<std::vector<int64_t>>&& shapes = {},
      int64_t node_id = -1,
      bool is_async = false)
      : name_(std::move(name)),
        kind_(kind),
        thread_id_(thread_id),
        handle_(handle),
        shapes_(std::move(shapes)),
        node_id_(node_id),
        is_async_(is_async) {
    record(record_cuda);
  }

  // Constructor to be used in conjunction with LegacyEvent::fromIValue.
  LegacyEvent(
      EventKind kind,
      at::StringView name,
      uint16_t thread_id,
      at::RecordFunctionHandle handle,
      std::vector<std::vector<int64_t>>&& shapes,
      int64_t node_id,
      bool is_remote,
      int64_t cpu_memory_usage,
      int64_t cpu_ns,
      bool cuda_recorded,
      int64_t cuda_memory_usage = 0,
      c10::DeviceIndex device = -1,
      double cuda_us = -1)
      : cpu_ns_(cpu_ns),
        name_(std::move(name)),
        kind_(kind),
        thread_id_(thread_id),
        handle_(handle),
        shapes_(std::move(shapes)),
        cpu_memory_usage_(cpu_memory_usage),
        cuda_memory_usage_(cuda_memory_usage),
        device_(device),
        node_id_(node_id),
        is_remote_(is_remote),
        cuda_us_(static_cast<int64_t>(cuda_us)) {
    // Sanity check values that were deserialized
    TORCH_INTERNAL_ASSERT(cpu_ns_ > 0);
    if (cuda_recorded) {
      TORCH_INTERNAL_ASSERT(device_ >= 0);
      TORCH_INTERNAL_ASSERT(cuda_us_ >= 0);
    }
  }

  // Returns IValues corresponding to event structure, to be used for
  // serialization.
  at::IValue toIValue() const;

  // Reconstructs an event from IValues given by toIValue.
  static LegacyEvent fromIValue(const at::IValue& eventIValue);

  void record(bool record_cuda);

  std::string kindStr() const {
    switch (kind_) {
      case EventKind::Mark:
        return "mark";
      case EventKind::PushRange:
        return "push";
      case EventKind::PopRange:
        return "pop";
      case EventKind::MemoryAlloc:
        return "memory_alloc";
    }
    throw std::runtime_error("unknown event kind");
  }

  EventKind kind() const {
    return kind_;
  }

  const char* name() const {
    return name_.str();
  }

  uint64_t threadId() const {
    return thread_id_;
  }

  std::vector<std::vector<int64_t>> shapes() const {
    return shapes_;
  }

  double cpuElapsedUs(const LegacyEvent& e) const {
    return static_cast<double>(e.cpu_ns_ - cpu_ns_) / (1000.0);
  }

  void setCpuUs(int64_t cpu_us) {
    cpu_ns_ = cpu_us * 1000;
  }

  double cpuUs() const {
    return static_cast<double>(cpu_ns_) / (1000.0);
  }

  double cudaElapsedUs(const LegacyEvent& e) const;

  bool hasCuda() const {
    return cuda_event != nullptr || (isRemote() && device_ != -1);
  }

  c10::DeviceIndex device() const {
    return device_;
  }

  void updateMemoryStats(int64_t alloc_size, c10::Device device) {
    if (device.is_cuda() || device.type() == c10::DeviceType::HIP) {
      cuda_memory_usage_ = alloc_size;
    } else if (
        device.is_cpu() || device.type() == c10::DeviceType::MKLDNN ||
        device.type() == c10::DeviceType::IDEEP) {
      cpu_memory_usage_ = alloc_size;
    } else {
      LOG(WARNING) << "Unsupported memory profiling device: " << device;
    }
  }

  int64_t cpuMemoryUsage() const {
    return cpu_memory_usage_;
  }

  int64_t cudaMemoryUsage() const {
    return cuda_memory_usage_;
  }

  at::RecordFunctionHandle handle() const {
    return handle_;
  }

  // Node ID corresponding to this event.
  int64_t nodeId() const {
    return node_id_;
  }

  // Set Node ID on this event.
  void setNodeId(int64_t node_id) {
    node_id_ = node_id;
  }

  void setName(at::StringView newName_) {
    name_ = std::move(newName_);
  }

  bool isRemote() const {
    return is_remote_;
  }

  void setCudaUs(int64_t cuda_us) {
    cuda_us_ = cuda_us;
  }

  void setSequenceNr(int64_t sequence_nr) {
    sequence_nr_ = sequence_nr;
  }

  int64_t sequenceNr() const {
    return sequence_nr_;
  }

  void setCorrelationId(uint64_t correlation_id) {
    correlation_id_ = correlation_id;
  }

  uint64_t correlationId() const {
    return correlation_id_;
  }

  const std::vector<std::string>& stack() const {
    return stack_;
  }

  void setStack(const std::vector<std::string>& stack) {
    stack_ = stack;
  }

  uint64_t fwdThreadId() const {
    return fwd_thread_id_;
  }

  void setFwdThreadId(uint64_t fwd_thread_id) {
    fwd_thread_id_ = fwd_thread_id;
  }

  uint8_t scope() const {
    return scope_;
  }

  void setScope(uint8_t scope) {
    scope_ = scope;
  }

  const std::unordered_map<std::string, c10::IValue>& extraArgs() const {
    return extra_args_;
  }

  void setExtraArgs(std::unordered_map<std::string, c10::IValue>&& save_args) {
    extra_args_ = std::move(save_args);
  }

  uint64_t flops() {
    return flops_;
  }

  bool isAsync() {
    return is_async_;
  }

  void setFlops(uint64_t flops) {
    flops_ = flops;
  }

 private:
  // signed to allow for negative intervals, initialized for safety.
  int64_t cpu_ns_ = 0;
  at::StringView name_;
  EventKind kind_;
  uint64_t thread_id_;
  uint64_t fwd_thread_id_{0};
  at::RecordFunctionHandle handle_{0};
  std::vector<std::vector<int64_t>> shapes_;
  int64_t cpu_memory_usage_ = 0;
  int64_t cuda_memory_usage_ = 0;
  c10::DeviceIndex device_ = -1;
  torch::profiler::impl::ProfilerVoidEventStub cuda_event = nullptr;
  int64_t node_id_ = 0;
  bool is_remote_ = false;
  int64_t cuda_us_ = -1;
  int64_t sequence_nr_ = -1;
  bool is_async_ = false;
  std::vector<std::string> stack_;
  uint8_t scope_{0};
  uint64_t correlation_id_{0};
  // Extra arguments for computing op flops
  std::unordered_map<std::string, c10::IValue> extra_args_;
  uint64_t flops_ = 0;
};

// A mutex-protected list of profiling events. Storage is reserved up front so
// that appending an event from inside a profiling region rarely triggers a
// std::vector resize.
struct RangeEventList {
  RangeEventList() {
    events_.reserve(kReservedCapacity);
  }

  template <typename... Args>
  void record(Args&&... args) {
    std::lock_guard<std::mutex> guard(mutex_);
    events_.emplace_back(std::forward<Args>(args)...);
  }

  std::vector<LegacyEvent> consolidate() {
    std::lock_guard<std::mutex> lock(mutex_);
    std::vector<LegacyEvent> result;
    result.insert(
        result.begin(),
        std::make_move_iterator(events_.begin()),
        std::make_move_iterator(events_.end()));
    events_.erase(events_.begin(), events_.end());
    return result;
  }

  size_t size() {
    std::lock_guard<std::mutex> lock(mutex_);
    return events_.size();
  }

 private:
  // This mutex is used to serialize access when different threads are writing
  // to the same instance of RangeEventList.
  std::mutex mutex_;
  std::vector<LegacyEvent> events_;

  static const size_t kReservedCapacity = 1024;
};

// A struct to control settings of disableProfiler options.
struct TORCH_API ProfilerDisableOptions {
  ProfilerDisableOptions() = default;
  ProfilerDisableOptions(bool shouldCleanupTLSState, bool shouldConsolidate)
      : cleanupTLSState(shouldCleanupTLSState),
        consolidate(shouldConsolidate) {}
  // Whether we should clean up profiler states that are thread local, such as
  // ThreadLocalDebugInfo and thread local RecordFunction callbacks.
  bool cleanupTLSState = true;
  // Whether we should consolidate all currently recorded profiled events. If
  // false, will not consolidate and other threads can continue to write to the
  // event lists.
  bool consolidate = true;
};
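// Illustrative sketch, not part of the upstream header: one way these options
// might be passed to disableProfilerLegacy() (declared below). With both flags
// false, thread-local profiler state is kept and events are not consolidated,
// so other threads can continue recording.
//
//   ProfilerDisableOptions opts(
//       /*shouldCleanupTLSState=*/false,
//       /*shouldConsolidate=*/false);
//   thread_event_lists lists = disableProfilerLegacy(opts);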
// NOTE: profiler mode is thread local, with automatic propagation
// across thread boundary (e.g. at::launch tasks)
TORCH_API void enableProfilerLegacy(
    const torch::profiler::impl::ProfilerConfig&);

using thread_event_lists = std::vector<std::vector<LegacyEvent>>;

TORCH_API thread_event_lists disableProfilerLegacy(
    std::optional<ProfilerDisableOptions> profilerDisableOptions =
        std::nullopt);

// Adds profiledEvents to the current thread local recorded events. Each event
// will be marked with node ID given by fromNodeId.
TORCH_API void addEventList(std::vector<LegacyEvent>&& profiledEvents);

// Writes profiled events to a stream.
TORCH_API void writeProfilerEventsToStream(
    std::ostream& out,
    const std::vector<LegacyEvent*>& events);

// Usage:
// {
//   RecordProfile guard("filename.trace");
//   // code you want to profile
// }
// Then open filename.trace in chrome://tracing
struct TORCH_API RecordProfile {
  RecordProfile(std::ostream& out);
  RecordProfile(const std::string& filename);

  ~RecordProfile();

 private:
  void init();
  std::unique_ptr<std::ofstream> file_;
  std::ostream& out_;
  void processEvents(const std::vector<LegacyEvent*>& events);
};

// A guard that enables the legacy profiler, taking in an optional callback to
// process the results.
// Usage:
// {
//   TLSLegacyProfilerGuard g([](thread_event_lists profilerResults) {
//     // process profilerResults
//   });
//   Code to profile
// }
struct TORCH_API TLSLegacyProfilerGuard {
  explicit TLSLegacyProfilerGuard(
      const torch::profiler::impl::ProfilerConfig& cfg,
      std::optional<std::function<void(const thread_event_lists&)>>
          resultCallback = std::nullopt,
      std::optional<ProfilerDisableOptions> profilerDisableOptions =
          std::nullopt)
      : cb_(std::move(resultCallback)),
        profilerDisableOptions_(profilerDisableOptions) {
    enableProfilerLegacy(cfg);
  }

  ~TLSLegacyProfilerGuard() {
    thread_event_lists event_lists =
        disableProfilerLegacy(profilerDisableOptions_);
    if (cb_) {
      try {
        (*cb_)(event_lists);
      } catch (const std::exception& e) {
        LOG(ERROR) << "Got error processing profiler events: " << e.what();
      }
    }
  }

 private:
  std::optional<std::function<void(const thread_event_lists&)>> cb_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const std::optional<ProfilerDisableOptions> profilerDisableOptions_;
};

} // namespace torch::autograd::profiler
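// Illustrative end-to-end sketch, not part of the upstream header: driving the
// legacy profiler directly instead of through TLSLegacyProfilerGuard. The
// ProfilerConfig construction below is an assumption (that it can be built
// from a torch::profiler::impl::ProfilerState value such as
// ProfilerState::CPU); check the actual ProfilerConfig signature before use.
//
//   using namespace torch::autograd::profiler;
//
//   torch::profiler::impl::ProfilerConfig cfg(
//       torch::profiler::impl::ProfilerState::CPU);
//   enableProfilerLegacy(cfg);
//   // ... run the code to be profiled ...
//   thread_event_lists lists = disableProfilerLegacy();
//
//   for (const auto& per_thread_events : lists) {
//     for (const LegacyEvent& evt : per_thread_events) {
//       // Inspect evt.name(), evt.kindStr(), evt.cpuUs(), evt.threadId(), ...
//     }
//   }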