#pragma once

#ifdef USE_C10D_GLOO

// NOTE(review): in the text under review every `<...>` span had been stripped
// (include targets and template arguments) and all line breaks were lost, so
// `//` comments swallowed the declarations that followed them. The include
// targets and template arguments in this header are reconstructed; confirm
// them against the c10d `Backend` interface before merging.
#include <torch/csrc/distributed/c10d/ProcessGroupGloo.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

// Backend that wraps another backend. It holds two backends:
//  * `backend_`: the underlying process group that actual application
//    collectives are dispatched to, and
//  * `glooBackend_`: a Gloo process group used for internal coordination such
//    as monitored barrier, sequence number checking and collective
//    fingerprint collecting.
// Only declared when USE_C10D_GLOO is defined. Method definitions are not part
// of this header.
class TORCH_API ProcessGroupWrapper : public Backend {
 public:
  // `backend` is the process group being wrapped; `glooBackend` is the Gloo
  // process group used for internal coordination (presumably stored in
  // `backend_` / `glooBackend_` respectively -- the definition is not in this
  // header).
  explicit ProcessGroupWrapper(
      const c10::intrusive_ptr<Backend>& backend,
      c10::intrusive_ptr<Backend> glooBackend);

  const std::string getBackendName() const override;

  // Collective operations. Signatures mirror the `Backend` virtuals they
  // override, including the default-constructed option arguments.
  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& data,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& data,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  // This function is deprecated and will be moved out of ProcessGroup to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your ProcessGroup, implement _allgather_base
  //   instead.
  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& outputTensorLists,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputSplitSizes,
      std::vector<int64_t>& inputSplitSizes,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  // Note: unlike the collectives above, `opts` has no default here; only
  // `waitAllRanks` does.
  void monitoredBarrier(const BarrierOptions& opts, bool waitAllRanks = false)
      override;

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only implemented
  // for GLOO and NCCL backends currently.
  // Original author note: "don't implement this".
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override; // just call underlying

  // Point-to-point operations.
  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  // Note: `opts` has no default argument in this override.
  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputBuffer,
      at::Tensor& inputBuffer,
      const ReduceScatterOptions& opts) override;

  void startCoalescing() override;

  c10::intrusive_ptr<Work> endCoalescing() override;

  // Accessor for the wrapped process group (presumably returns `backend_`;
  // the definition is not in this header).
  c10::intrusive_ptr<Backend> getWrappedPg() const;

 private:
  // Underlying process group that actual application collectives will be
  // dispatched to
  c10::intrusive_ptr<Backend> backend_;
  // Gloo process group responsible for internal coordination such as monitored
  // barrier, sequence number checking, collective fingerprint collecting.
  c10::intrusive_ptr<Backend> glooBackend_;
  // Conducts several checks to ensure that the underlying collective is well
  // formed with the goal of notifying the user about incorrect collective use
  // in the application.
  void runCollectiveChecks(
      OpType op_type,
      const std::vector<at::Tensor>& tensors);
};

} // namespace c10d

#endif // USE_C10D_GLOO