#pragma once

#include <chrono>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/debug.h>

constexpr auto kBackendDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000);

namespace c10d {

class TORCH_API Backend : public torch::CustomClassHolder {
 public:
  // Backend Options is a base struct that defines the basic options
  // when constructing a Backend. Each Backend subclass should
  // extend this struct and define its options if it wants to provide more
  // config options (beyond the basic ones defined here) to the end user.
  struct TORCH_API Options : torch::CustomClassHolder {
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kBackendDefaultTimeout)
        : timeout(timeout), backend(std::move(backend)) {}
    ~Options() override = default;

    std::chrono::milliseconds timeout;

    // backend name
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    const std::string backend;
  };

  explicit Backend(int rank, int size);
  ~Backend() override = 0;

  int getRank() const {
    return rank_;
  }

  int getSize() const {
    return size_;
  }

  // Returns a unique opaque ID of this backend that can be used to correlate
  // with its collectives.
  int64_t getID() const {
    return reinterpret_cast<std::intptr_t>(this);
  }

  virtual bool supportsSplitting() const {
    return false;
  }

  virtual void startCoalescing() {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not implement startCoalescing"));
  }

  virtual c10::intrusive_ptr<Work> endCoalescing() {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not implement endCoalescing"));
  }

  // Subclasses must override this method to return the backend name
  virtual const std::string getBackendName() const {
    TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
  }

  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& /* tensors */,
      const BroadcastOptions& /* opts */ = BroadcastOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support broadcast"));
  }

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allreduce"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_sparse(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce sparse"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceCoalescedOptions& /* opts */ =
          AllreduceCoalescedOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& /* tensors */,
      const ReduceOptions& /* opts */ = ReduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support reduce"));
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allgather"));
  }
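  // Illustrative sketch (not part of this interface): a hypothetical subclass
  // (names such as `MyBackend` are placeholders) overrides only the
  // collectives it supports, e.g.
  //
  //   class MyBackend : public Backend {
  //    public:
  //     MyBackend(int rank, int size) : Backend(rank, size) {}
  //
  //     const std::string getBackendName() const override {
  //       return "my_backend";
  //     }
  //
  //     c10::intrusive_ptr<Work> allreduce(
  //         std::vector<at::Tensor>& tensors,
  //         const AllreduceOptions& opts = AllreduceOptions()) override {
  //       // launch the reduction and return a Work handle tracking it
  //     }
  //   };
  //
  // Any collective left un-overridden keeps the default TORCH_CHECK(false)
  // body above and reports "does not support ..." at call time.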
  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in the near future.
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support _allgather_base"));
  }

  // This function is deprecated and will be moved out of Backend to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your Backend, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_coalesced"));
  }

  // This function is a coalesced version of `allgather_into_tensor` (currently
  // still named as `_allgather_base`). Each tensor in the vector corresponds
  // to an input/output of one `allgather_into_tensor` operation.
  virtual c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& /* outputs */,
      std::vector<at::Tensor>& /* inputs */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_into_tensor_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const GatherOptions& /* opts */ = GatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support gather"));
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ScatterOptions& /* opts */ = ScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support scatter"));
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support reduce_scatter"));
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support _reduce_scatter_base"));
  }
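  // Illustrative sketch (not part of this interface) of the flat-buffer
  // contract for `_allgather_base`: the output buffer holds WORLD_SIZE copies
  // of the input. `backend`, `n`, and `world_size` below are placeholders.
  //
  //   at::Tensor input = at::ones({n});                 // per-rank chunk
  //   at::Tensor output = at::empty({world_size * n});  // contiguous result
  //   backend->_allgather_base(output, input)->wait();
  //
  // `_reduce_scatter_base` is the mirror image: its input buffer is
  // WORLD_SIZE times the size of the per-rank output chunk.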
  // This function is a coalesced version of `reduce_scatter_tensor` (currently
  // still named as `_reduce_scatter_base`). Each tensor in the vector
  // corresponds to an input/output of one `reduce_scatter_tensor` operation.
  virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& /* outputs */,
      std::vector<at::Tensor>& /* inputs */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support reduce_scatter_tensor_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      std::vector<int64_t>& /* outputSplitSizes */,
      std::vector<int64_t>& /* inputSplitSizes */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support alltoall_base"));
  }

  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support alltoall"));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */,
      bool /* unused */ = false) {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not support monitoredBarrier, only GLOO supports monitored barrier."));
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only
  // implemented for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& /* tensors */,
      int /* dstRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support send"));
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& /* tensors */,
      int /* srcRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support recv"));
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& /* tensors */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support recvAnysource"));
  }

  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& /* opts */ = BarrierOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support barrier"));
  }

  virtual void registerOnCompletionHook(
      std::function<void(std::shared_ptr<WorkInfo>)>&& hook) {
    TORCH_CHECK(
        false,
        "Only ProcessGroupNCCL supports onCompletion hook, but got ",
        getBackendName(),
        " backend.");
  }

  virtual void waitForPendingWorks() {
    TORCH_CHECK(
        false,
        "Only ProcessGroupNCCL supports waitForPendingWorks, but got ",
        getBackendName(),
        " backend.");
  }

  virtual void enableCollectivesTiming() {
    TORCH_CHECK(
        false,
        "Backend ",
        getBackendName(),
        " is missing implementation of enableCollectivesTiming.");
  }

  bool hasHooks() const {
    return onCompletionHook_ != nullptr;
  }
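  // Illustrative sketch (not part of this interface): point-to-point transfers
  // return a Work handle the caller waits on; `backend`, `tensors`, and the
  // rank/tag values below are placeholders.
  //
  //   if (backend->getRank() == 0) {
  //     backend->send(tensors, /*dstRank=*/1, /*tag=*/0)->wait();
  //   } else if (backend->getRank() == 1) {
  //     backend->recv(tensors, /*srcRank=*/0, /*tag=*/0)->wait();
  //   }
  //
  // Backends that do not override send/recv fail here with the
  // "does not support" TORCH_CHECK above.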
  // Do not call this directly, use ProcessGroup::setGroupName instead.
  void setGroupUid(const std::string& pg_uid) {
    pg_uid_ = pg_uid;
  }

  const std::string& getGroupUid() const {
    return pg_uid_;
  }

  void setGroupDesc(const std::string& desc) {
    pg_desc_ = desc;
  }

  const std::string& getGroupDesc() const {
    return pg_desc_;
  }

  // See similar functions in ProcessGroup.hpp for context.
  std::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
  }

  // Perform an eager connect to the specified device if the backend supports
  // it.
  virtual void eagerConnectSingleDevice(at::Device device) {
    // no-op in the default case; this is an optimization some
    // backends may perform
  }

  void setBoundDeviceId(std::optional<at::Device> device) {
    if (device) {
      TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
    }
    bound_device_id_ = device;
  }

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int rank_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int size_;
  // Debug level setting. It is parsed once when ProcessGroup is constructed
  // and remains the same across use of this process group.
  DebugLevel dist_debug_level_;
  std::string pg_uid_;
  std::string pg_desc_;

  std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;

  std::optional<at::Device> bound_device_id_;
};

} // namespace c10d
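// Illustrative sketch (not part of this header): a hypothetical subclass
// constructor (`MyBackend` is a placeholder) typically forwards rank/size to
// Backend and calls the protected init() to set up logging, e.g.
//
//   MyBackend::MyBackend(int rank, int size) : Backend(rank, size) {
//     init();
//   }
//
// Note that binding a backend to a device requires an explicit device index,
// e.g. setBoundDeviceId(at::Device(at::kCUDA, 0)); a device without an index
// trips the TORCH_CHECK inside setBoundDeviceId.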