#pragma once #include #include #include #include namespace c10d { namespace detail { class TCPServer; class TCPClient; struct SocketAddress { std::string host{}; std::uint16_t port{}; }; } // namespace detail struct TCPStoreOptions { static constexpr std::uint16_t kDefaultPort = 29500; std::uint16_t port = kDefaultPort; bool isServer = false; std::optional numWorkers = std::nullopt; bool waitWorkers = true; std::chrono::milliseconds timeout = Store::kDefaultTimeout; // A boolean value indicating whether multiple store instances can be // initialized with the same host:port pair. bool multiTenant = false; // If specified, and if isServer is true, the underlying TCPServer will take // over the bound socket associated to this fd. This option is useful to avoid // port assignment races in certain scenarios. std::optional masterListenFd = std::nullopt; // A boolean value indicating whether to use the experimental libUV backend. bool useLibUV = true; }; class TORCH_API TCPStore : public Store { public: static constexpr std::chrono::milliseconds kConnectRetryDelay{1000}; explicit TCPStore(std::string host, const TCPStoreOptions& opts = {}); [[deprecated("Use TCPStore(host, opts) instead.")]] explicit TCPStore( const std::string& masterAddr, std::uint16_t masterPort, std::optional numWorkers = std::nullopt, bool isServer = false, const std::chrono::milliseconds& timeout = kDefaultTimeout, bool waitWorkers = true); ~TCPStore() override; void set(const std::string& key, const std::vector& value) override; std::vector compareSet( const std::string& key, const std::vector& expectedValue, const std::vector& desiredValue) override; std::vector get(const std::string& key) override; int64_t add(const std::string& key, int64_t value) override; bool deleteKey(const std::string& key) override; bool check(const std::vector& keys) override; int64_t getNumKeys() override; void wait(const std::vector& keys) override; void wait( const std::vector& keys, const std::chrono::milliseconds& timeout) override; void append(const std::string& key, const std::vector& value) override; std::vector> multiGet( const std::vector& keys) override; void multiSet( const std::vector& keys, const std::vector>& values) override; bool hasExtendedApi() const override; // Waits for all workers to join. void waitForWorkers(); // Returns the hostname used by the TCPStore. const std::string& getHost() const noexcept { return addr_.host; } // Returns the port used by the TCPStore. std::uint16_t getPort() const noexcept { return addr_.port; } bool isLibUvBackend() const noexcept { return usingLibUv_; } // note(xilunwu): this function is only for internal testing void _splitSet(const std::string& key, const std::vector& data); std::string repr() const; private: int64_t incrementValueBy(const std::string& key, int64_t delta); void ping(); void validate(); std::vector doGet(const std::string& key); void doWait( c10::ArrayRef keys, std::chrono::milliseconds timeout); detail::SocketAddress addr_; std::shared_ptr server_; std::unique_ptr client_; std::optional numWorkers_; const std::string initKey_ = "init/"; const std::string keyPrefix_ = "/"; std::mutex activeOpLock_; bool usingLibUv_ = true; }; } // namespace c10d