#pragma once

#ifdef USE_C10D_UCC

#include <torch/csrc/distributed/c10d/UCCUtils.hpp>

#include <cstdint>
#include <string>
#include <vector>

namespace c10d {

// Records one collective into `_comms_tracer` via recordComms() when
// torch_ucc_config.enable_comms_logger is set; otherwise does nothing.
// The address of the object returned by `_work.get()` is passed as the
// opaque `workReq` value. The do { ... } while (0) wrapper makes the macro
// expand to a single statement, so it is safe inside unbraced if/else.
// `_comms_tracer` and `_work` are parenthesized because postfix operators
// (`->`, `.get()`) are applied to them; the remaining arguments are forwarded
// unchanged as function arguments (so e.g. `{}` can still be passed).
#define RECORD_COMMS_TRACE(                                                    \
    _comms_tracer, _work, _opType, _rank, _comm_size, _inTensors, _outTensors) \
  do {                                                                         \
    if (torch_ucc_config.enable_comms_logger) {                                \
      (_comms_tracer)->recordComms(                                            \
          opTypeToString(_opType),                                             \
          reinterpret_cast<uintptr_t>((_work).get()),                          \
          _rank,                                                               \
          _comm_size,                                                          \
          _inTensors,                                                          \
          _outTensors);                                                        \
    }                                                                          \
  } while (0)

// Interfaces to collect communication traces.
// Traces accumulate in an internal vector of strings that callers read back
// through getCommsTrace(). Member function bodies live in the corresponding
// .cpp file (not shown here).
class TORCH_API CommTraceLogger : public torch::CustomClassHolder {
 private:
  // Collected trace entries, exposed via getCommsTrace().
  std::vector<std::string> comms_trace_;
  std::vector<std::string> curBlocks_; /* unused */
  // Presumably filled by the split-size overload of recordOptionalInfo()
  // for Alltoallv -- confirm in the .cpp.
  std::vector<int64_t> curOutSplitSizes_;
  std::vector<int64_t> curInSplitSizes_;
  // Presumably filled by recordOptionalInfo(int root); initialized to -1.
  int curRoot_ = -1;
  // Sequence counter; presumably advanced per recorded collective -- confirm.
  unsigned long seqnum = 0;

 public:
  void setCurBlock(const std::string& name); /* unused */
  void popBlock(); /* unused */

  // Record root info if applicable, e.g., broadcast, gather, scatter.
  // NOTE(review): both recordOptionalInfo overloads default every parameter,
  // so a call with no arguments, recordOptionalInfo(), is ambiguous and will
  // not compile. Always pass at least one argument.
  void recordOptionalInfo(int root = -1);

  // Record input/output split sizes of Alltoallv.
  void recordOptionalInfo(
      const std::vector<int64_t>& outputSplitSizes = {},
      const std::vector<int64_t>& inputSplitSizes = {});

  // Record essential comms information for one collective.
  //   collName     - collective name; RECORD_COMMS_TRACE passes
  //                  opTypeToString(opType)
  //   workReq      - opaque request id; RECORD_COMMS_TRACE passes the address
  //                  of the work object (default 0)
  //   rank         - calling rank (default -1)
  //   world_size   - communicator size (default -1)
  //   inputTensors - input tensors of the collective (default empty)
  //   outputTensor - output tensors of the collective (default empty)
  void recordComms(
      const std::string& collName,
      const uintptr_t workReq = 0,
      const int rank = -1,
      const int world_size = -1,
      const std::vector<at::Tensor>& inputTensors = {},
      const std::vector<at::Tensor>& outputTensor = {});

  // Return the collected comms traces. The reference is mutable, so callers
  // can modify the stored traces in place.
  std::vector<std::string>& getCommsTrace() {
    return comms_trace_;
  }

  // Read-only access to the collected comms traces for const contexts.
  const std::vector<std::string>& getCommsTrace() const {
    return comms_trace_;
  }
};

} // namespace c10d

#endif // USE_C10D_UCC