#include #include #include #include #include #include #include #include #include #include #include namespace c10d { static ProcessGroup::BackendType strToBackendType(std::string_view backend) { if (backend == "undefined") { return ProcessGroup::BackendType::UNDEFINED; } else if (backend == "gloo") { return ProcessGroup::BackendType::GLOO; } else if (backend == "nccl") { return ProcessGroup::BackendType::NCCL; } else if (backend == "ucc") { return ProcessGroup::BackendType::UCC; } else if (backend == "mpi") { return ProcessGroup::BackendType::MPI; } else { return ProcessGroup::BackendType::CUSTOM; } } std::string opTypeToString(OpType opType) { switch (opType) { case OpType::BROADCAST: return "BROADCAST"; case OpType::ALLREDUCE: return "ALLREDUCE"; case OpType::ALLREDUCE_COALESCED: return "ALLREDUCE_COALESCED"; case OpType::REDUCE: return "REDUCE"; case OpType::ALLGATHER: return "ALLGATHER"; case OpType::_ALLGATHER_BASE: return "_ALLGATHER_BASE"; case OpType::ALLGATHER_COALESCED: return "ALLGATHER_COALESCED"; case OpType::GATHER: return "GATHER"; case OpType::SCATTER: return "SCATTER"; case OpType::REDUCE_SCATTER: return "REDUCE_SCATTER"; case OpType::ALLTOALL_BASE: return "ALLTOALL_BASE"; case OpType::ALLTOALL: return "ALLTOALL"; case OpType::SEND: return "SEND"; case OpType::RECV: return "RECV"; case OpType::RECVANYSOURCE: return "RECVANYSOURCE"; case OpType::BARRIER: return "BARRIER"; case OpType::UNKNOWN: return "UNKNOWN"; case OpType::_REDUCE_SCATTER_BASE: return "_REDUCE_SCATTER_BASE"; case OpType::COALESCED: return "COALESCED"; case OpType::_ALLREDUCE_SPARSE: return "_ALLREDUCE_SPARSE"; default: TORCH_INTERNAL_ASSERT(false, "Unknown op type!"); } return "UNKNOWN"; } bool isP2POp(OpType opType, bool batchP2P /*= false*/) { if (batchP2P) return false; return opType == OpType::SEND || opType == OpType::RECV || opType == OpType::RECVANYSOURCE; } c10::intrusive_ptr ProcessGroup::getBackend( c10::DeviceType deviceType) { // If there is a backend associated with this device type then return it if (deviceTypeToBackend_.find(deviceType) != deviceTypeToBackend_.end()) { return deviceTypeToBackend_.at(deviceType); } // Get the backend type associated with the device ProcessGroup::BackendType backendType{ProcessGroup::BackendType::UNDEFINED}; try { backendType = deviceTypeToBackendType_.at(deviceType); } catch (const std::out_of_range& e) { TORCH_CHECK( false, "No backend type associated with device type ", deviceType); } // Check if the backend has already been initialized if (backendTypeToBackend_.find(backendType) != backendTypeToBackend_.end()) { auto backend = backendTypeToBackend_.at(backendType); deviceTypeToBackend_[deviceType] = backend; return backend; } TORCH_CHECK( false, "Could not retrieve or create the backend ", backendType, " for device type ", deviceType); } ProcessGroup::ProcessGroup( const c10::intrusive_ptr<::c10d::Store>& store, int rank, int size, c10::intrusive_ptr options) : store_(store), rank_(rank), size_(size), options_(std::move(options)), backendType_(strToBackendType(options_->backend)), dist_debug_level_(debug_level()) { C10_LOG_API_USAGE_ONCE("c10d.process_group"); } ProcessGroup::ProcessGroup(int rank, int size) : rank_(rank), size_(size), backendType_(BackendType::UNDEFINED) {} ProcessGroup::~ProcessGroup() = default; void ProcessGroup::init() { C10_LOG_API_USAGE_ONCE( fmt::format("c10d.process_group_{}", getBackendName())); } const std::string& ProcessGroup::getGroupName() const { TORCH_CHECK(!deviceTypeToBackend_.empty(), "ProcessGroup name not set"); return deviceTypeToBackend_.begin()->second->getGroupUid(); } void ProcessGroup::setGroupName(const std::string& name) { for (auto& kv : deviceTypeToBackend_) { kv.second->setGroupUid(name); } } const std::string& ProcessGroup::getGroupDesc() const { return pg_desc_; } void ProcessGroup::setGroupDesc(const std::string& name) { pg_desc_ = name; // Also set the group desc for all backends for (auto& kv : deviceTypeToBackend_) { kv.second->setGroupDesc(name); } } void ProcessGroup::enableCollectivesTiming() { for (auto& kv : deviceTypeToBackend_) { kv.second->enableCollectivesTiming(); } } void ProcessGroup::release_resources() { store_.reset(); deviceTypeToBackend_.clear(); backendTypeToBackend_.clear(); } } // namespace c10d