#pragma once #include #include #include #include #include #include namespace torch { class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { public: ParamCommsDebugInfo() = default; ParamCommsDebugInfo( std::tuple pgName, int rank, std::string&& collName, int64_t inNelems, int64_t outNelems, at::ScalarType dType, std::vector inSplitSizes, std::vector outSplitSizes, int globalRankStart, int globalRankStride, int worldSize); ~ParamCommsDebugInfo() override = default; const std::string getProcessGroupName() const { return std::get<0>(pgName_); } const std::string getProcessGroupDesc() const { return std::get<1>(pgName_); } int getRank() const { return rank_; } int getWorldSize() const { return worldSize_; } int getGlobalRankStart() const { return globalRankStart_; } int getGlobalRankStride() const { return globalRankStride_; } const std::string getCollectiveName() const { return collectiveName_; } int64_t getInMessageNelems() const { return inMessageNelems_; } int64_t getOutMessageNelems() const { return outMessageNelems_; } at::ScalarType getDType() const { return dType_; } const std::vector& getInputSplitSizes() const { return inputSplitSizes_; } const std::vector& getOutputSplitSizes() const { return outputSplitSizes_; } const std::vector& getGroupRanks() const { return groupRanks_; } private: std::tuple pgName_; // int rank_{}; int worldSize_{}; std::string collectiveName_; int64_t inMessageNelems_{}; int64_t outMessageNelems_{}; at::ScalarType dType_ = at::kByte; std::vector inputSplitSizes_; std::vector outputSplitSizes_; int globalRankStart_{}; int globalRankStride_{}; std::vector groupRanks_{}; }; #define RECORD_PARAM_COMMS( \ seq, \ pgName, \ rank, \ collName, \ inNelems, \ outNelems, \ dType, \ inSplitSizes, \ outSplitSizes, \ globalRankStart, \ globalRankStride, \ worldSize) \ auto paramCommsInfo = std::make_shared( \ pgName, \ rank, \ collName, \ inNelems, \ outNelems, \ dType, \ inSplitSizes, \ outSplitSizes, \ globalRankStart, \ globalRankStride, \ worldSize); \ c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ std::initializer_list paramList = { \ c10::IValue(seq), \ pgName, \ rank, \ collName, \ inSplitSizes, \ outSplitSizes, \ globalRankStart, \ globalRankStride, \ worldSize}; \ c10::ArrayRef paramInputs(paramList); \ RECORD_FUNCTION(at::kParamCommsCallName, paramInputs); #define RECORD_PARAM_COMMS_DATA( \ seq, \ pgName, \ InputTensors, \ OutputTensors, \ rank, \ collName, \ inNelems, \ outNelems, \ dType, \ inSplitSizes, \ outSplitSizes, \ globalRankStart, \ globalRankStride, \ worldSize) \ auto paramCommsInfo = std::make_shared( \ pgName, \ rank, \ collName, \ inNelems, \ outNelems, \ dType, \ inSplitSizes, \ outSplitSizes, \ globalRankStart, \ globalRankStride, \ worldSize); \ c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ std::initializer_list paramList = { \ c10::IValue(InputTensors), \ c10::IValue(seq), \ pgName, \ rank, \ collName, \ inSplitSizes, \ outSplitSizes, \ globalRankStart, \ globalRankStride, \ worldSize}; \ c10::ArrayRef paramInputs(paramList); \ RECORD_FUNCTION_WITH_INPUTS_OUTPUTS( \ at::kParamCommsCallName, \ paramInputs, \ std::vector(1, c10::IValue(OutputTensors))); } // namespace torch