#pragma once

#ifdef USE_C10D_GLOO

#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <vector>

#include <gloo/algorithm.h>
#include <gloo/common/error.h>
#include <gloo/context.h>
#include <gloo/rendezvous/store.h>
#include <gloo/transport/device.h>
#include <gloo/transport/unbound_buffer.h>

#include <torch/csrc/distributed/c10d/Backend.hpp>
#include <torch/csrc/distributed/c10d/Store.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>

namespace c10d {

constexpr const char* GLOO_BACKEND_NAME = "gloo";

// ProcessGroupGloo implements Gloo bindings for c10d.
//
// All functions on this class are expected to be called in the same
// order across processes in the group. This is the only way that we
// can guarantee to match up the same calls across processes. For
// multi-threaded usage of process groups, you can consider using
// multiple process group instances.
//
// The Gloo algorithms that this class calls into are cached by their
// signature (see description of AlgorithmKey above). This cache works
// as follows: every function call instantiates an AlgorithmKey and
// looks in the cache for existing entries. If there is one, it is
// removed from the cache and returned to the caller. If there are
// none, a new entry is created and returned. If an entry was created
// before, but is still in use, the call will block and wait until the
// entry is returned to the cache.
//
// In the future, we hope to extend this to allow multiple entries per
// key, to enable parallelism for a single key. The number of entries
// per key must always be identical for all processes. This maximum
// number can be automatically tuned, but only if we let a single
// process take charge, and have it broadcast the limits.
//
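// A minimal illustration of the ordering requirement above (the process
// group `pg` and `tensors` are placeholders, not part of this header):
// if rank 0 runs
//
//   pg->broadcast(tensors)->wait();
//   pg->allreduce(tensors)->wait();
//
// then every other rank must issue the same two collectives in the same
// order. Issuing them in a different order, or skipping one, mismatches the
// calls across processes and can hang or fail.
//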
class TORCH_API ProcessGroupGloo : public Backend {
 public:
  // AsyncWork is the Gloo specific superclass for asynchronous work items.
  // We can split asynchronous work into 3 phases:
  // 1) Sanity checks and prepare input (e.g. memcpy)
  // 2) Run operation on background thread
  // 3) Synchronize with completion on foreground thread
  //
  // There is state to be shared between these 3 phases and all of this state
  // is captured in the AsyncWork class and its derivatives.
  //
  // Note: while we are porting operations to use new style collectives, there
  // is a split between operations using the existing caching approach and
  // operations using the new AsyncWork base class. Over time we will port
  // all operations and perform needed cleanup.
  //
  // FIXME: This probably should be called WorkGloo since the work is executed
  // in sync mode by a background thread.
  class TORCH_API AsyncWork : public Work {
   public:
    explicit AsyncWork(
        std::vector<std::vector<at::Tensor>> outputTensors,
        OpType opType,
        uint64_t seq,
        const char* profilingTitle = nullptr,
        const std::optional<std::vector<at::Tensor>>& inputTensors =
            std::nullopt);

    ~AsyncWork() override = default;

    static void execute(const c10::intrusive_ptr<AsyncWork>& work);

    virtual void run() = 0;

    std::vector<at::Tensor> result() override;

    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override;
    uint64_t getSequencenumber() const override;

   protected:
    friend class ProcessGroupGloo;

   private:
    void finishWorkGloo();
    void finishWorkGlooError(const std::exception_ptr& eptr);
    inline void recordAsyncWorkProfilingInfo(
        const char* profilingTitle,
        const std::optional<std::vector<at::Tensor>>& inputTensors);

    const std::vector<std::vector<at::Tensor>> outputTensors_;
    c10::intrusive_ptr<c10::ivalue::Future> future_;
    std::function<void()> recordFunctionBeforeCallback_;
    const uint64_t seq_;
  };

  // Wrap c10d store as Gloo store
  class TORCH_API GlooStore : public ::gloo::rendezvous::Store {
   public:
    GlooStore(const c10::intrusive_ptr<::c10d::Store>& store)
        : store_(store) {}

    void setUint(const std::string& key, const std::vector<uint8_t>& value) {
      store_->set(key, value);
    }

    void set(const std::string& key, const std::vector<char>& value) override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      store_->set(key, tmp);
    }

    std::vector<uint8_t> getUint(const std::string& key) {
      auto value = store_->get(key);
      return value;
    }

    std::vector<char> get(const std::string& key) override {
      auto value = store_->get(key);
      return std::vector<char>(value.begin(), value.end());
    }

    void wait(const std::vector<std::string>& keys) override {
      store_->wait(keys, ::c10d::Store::kDefaultTimeout);
    }

    void wait(
        const std::vector<std::string>& keys,
        const std::chrono::milliseconds& timeout) override {
      store_->wait(keys, timeout);
    }

#ifdef GLOO_STORE_HAS_STORE_V2
    bool has_v2_support() override {
      return store_->hasExtendedApi();
    }

    std::vector<std::vector<char>> multi_get(
        const std::vector<std::string>& keys) override {
      std::vector<std::vector<char>> res;
      for (auto& value : store_->multiGet(keys)) {
        res.emplace_back(value.begin(), value.end());
      }
      return res;
    }

    void multi_set(
        const std::vector<std::string>& keys,
        const std::vector<std::vector<char>>& values) override {
      std::vector<std::vector<uint8_t>> u_values;
      u_values.reserve(values.size());
      for (auto& value : values) {
        u_values.emplace_back(value.begin(), value.end());
      }
      store_->multiSet(keys, u_values);
    }

    void append(const std::string& key, const std::vector<char>& value)
        override {
      std::vector<uint8_t> tmp(value.begin(), value.end());
      return store_->append(key, tmp);
    }

    int64_t add(const std::string& key, int64_t value) override {
      return store_->add(key, value);
    }
#endif

   protected:
    c10::intrusive_ptr<::c10d::Store> store_;
  };

  // For send and recv operations there is no need to pass them to the
  // thread pool as they are entirely completed by the device thread.
  // This work object is used to synchronize completion of the send or
  // recv operation. It keeps a reference to the tensor it is
  // operating on to prevent it from being deallocated while the
  // operation is still in flight.
  class TORCH_API SendWork : public Work {
   public:
    explicit SendWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        uint64_t seq);

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

    uint64_t getSequencenumber() const override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    const uint64_t seq_;
  };

  class TORCH_API RecvWork : public Work {
   public:
    explicit RecvWork(
        at::Tensor& tensor,
        std::unique_ptr<::gloo::transport::UnboundBuffer> buffer,
        OpType opType,
        uint64_t seq,
        const char* profilingTitle = nullptr);

    int sourceRank() const override;

    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override;

    void abort() override;

    uint64_t getSequencenumber() const override;

   protected:
    at::Tensor tensor_;
    std::unique_ptr<::gloo::transport::UnboundBuffer> buffer_;
    int srcRank_;
    const uint64_t seq_;
  };

  struct TORCH_API Options : public Backend::Options {
    explicit Options(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout);

    // return intrusive_ptr of the object
    static c10::intrusive_ptr<Options> create(
        std::chrono::milliseconds timeout = kBackendDefaultTimeout) {
      return c10::make_intrusive<Options>(timeout);
    }

    std::vector<std::shared_ptr<::gloo::transport::Device>> devices;
    int threads;
  };

  const std::string getBackendName() const override {
    return std::string(GLOO_BACKEND_NAME);
  }

  // Helper functions to create a new device object.
  // They are static functions on this class to keep them logically
  // separate from the rest of the code base (e.g. torch/csrc/distributed).

  // Create new device instance for specific interface.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForInterface(
      const std::string& interface);

  // Create new device instance for specific hostname or address.
  static std::shared_ptr<::gloo::transport::Device> createDeviceForHostname(
      const std::string& hostname);

  // Create new device instance.
  // It tries to resolve this machine's hostname and bind to that address.
  // If that fails (i.e. the hostname doesn't resolve to an address), it
  // falls back to binding to the loopback address.
  static std::shared_ptr<::gloo::transport::Device> createDefaultDevice();
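  // Configuration sketch using the helpers above (illustrative only; the
  // interface name "eth0" and the thread count are placeholders, not
  // defaults):
  //
  //   auto opts = ProcessGroupGloo::Options::create();
  //   opts->devices.push_back(
  //       ProcessGroupGloo::createDeviceForInterface("eth0"));
  //   opts->threads = 2;
  //
  // When no specific interface or hostname is needed, createDefaultDevice()
  // provides the hostname-with-loopback-fallback behavior described above.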
  // Create ProcessGroupGloo instance.
  static c10::intrusive_ptr<ProcessGroupGloo> createProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      std::chrono::milliseconds timeout);

  explicit ProcessGroupGloo(
      const c10::intrusive_ptr<Store>& store,
      int rank,
      int size,
      c10::intrusive_ptr<Options> options = Options::create());

  ~ProcessGroupGloo() override;

  c10::intrusive_ptr<Options> getOptions() {
    return options_;
  }

  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override;

  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_sparse(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override;

  c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& tensors,
      const AllreduceCoalescedOptions& opts =
          AllreduceCoalescedOptions()) override;

  c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& tensors,
      const ReduceOptions& opts = ReduceOptions()) override;

  c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& output_tensor,
      at::Tensor& input_tensor,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& output_lists,
      std::vector<at::Tensor>& input_list,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& outputs,
      std::vector<at::Tensor>& inputs,
      const AllgatherOptions& opts = AllgatherOptions()) override;

  c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& outputs,
      std::vector<at::Tensor>& inputs,
      const GatherOptions& opts = GatherOptions()) override;

  c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ScatterOptions& opts = ScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputs,
      std::vector<std::vector<at::Tensor>>& inputs,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;

  c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& outputTensor,
      at::Tensor& inputTensor,
      std::vector<int64_t>& outputCounts,
      std::vector<int64_t>& inputCounts,
      const AllToAllOptions& opts = AllToAllOptions()) override;

  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override;

  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override;

  c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& tensors,
      int tag) override;

  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override;

  void enableCollectivesTiming() override;

  const std::unique_ptr<::gloo::rendezvous::Store>& _getStore() const {
    return store_;
  }

  // Similar to barrier(), but blocks rank 0 until all other ranks have
  // acknowledged that they are alive (through send/recv from rank 0). Rank 0
  // is able to report all failed ranks if waitAllRanks = true, otherwise
  // reports the first rank it detected as failed.
  void monitoredBarrier(
      const BarrierOptions& opts = BarrierOptions(),
      bool waitAllRanks = false) override;

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store.
  void setSequenceNumberForGroup() override;

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  uint64_t getSequenceNumberForGroup() override;

  int getNumThreads() {
    return options_->threads;
  }

 protected:
  std::unique_ptr<::gloo::rendezvous::Store> store_;
  const c10::intrusive_ptr<Options> options_;

  // Every Gloo context represents a set of connections to its peers.
  // In order to use more than one device (or allow for parallelism on
  // a single device), you need multiple contexts.
  std::vector<std::shared_ptr<::gloo::Context>> contexts_;
  std::vector<std::thread> threads_;
  bool stop_;

  // Incremented for every collective we kick off.
  // The value is used as tag for collective operations. Collectives are kicked
  // off in identical order across processes. Therefore the tag can be used
  // to match up operations during concurrent execution.
  uint32_t collectiveCounter_;

  // Returns next collective tag to use (uses collectiveCounter_).
  uint32_t nextTag();

  // Returns the context to use for the specified tag.
  // With `nextTag` returning an increasing number, this should lead
  // to contexts being used in a round-robin fashion.
  std::shared_ptr<::gloo::Context> getContext(uint32_t tag);

  // Entrypoint for worker threads.
  void runLoop(int workerIndex);

  // Queue work to run on worker thread.
  void enqueue(c10::intrusive_ptr<AsyncWork> work);

  // Keep both a queue of pending work, and a vector with in progress work.
  // Both of these can only be mutated when holding the queue lock.
  // We keep both around instead of just the queue, so we can grab a weak_ptr
  // to all in progress and pending work when executing a barrier.
  // When executing a barrier, we need to ensure that all prior work
  // has completed before completing itself.
  std::deque<c10::intrusive_ptr<AsyncWork>> workQueue_;
  std::vector<c10::intrusive_ptr<AsyncWork>> workInProgress_;
  std::mutex workMutex_;
  std::condition_variable workProduceCV_;
  std::condition_variable workConsumeCV_;

  uint64_t seq_{0};
};

} // namespace c10d

#endif // USE_C10D_GLOO
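// Usage sketch (illustrative only, not part of this header; assumes `store`
// is an already-rendezvoused c10::intrusive_ptr<c10d::Store> such as a
// TCPStore, and that every rank runs the same code with its own `rank` and
// the shared `size`):
//
//   auto options = c10d::ProcessGroupGloo::Options::create();
//   options->devices.push_back(
//       c10d::ProcessGroupGloo::createDefaultDevice());
//   auto pg = c10::make_intrusive<c10d::ProcessGroupGloo>(
//       store, rank, size, options);
//
//   std::vector<at::Tensor> tensors{at::ones({4})};
//   pg->allreduce(tensors)->wait(); // blocks until the summed result lands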