#pragma once #include #include #include #include #include #include #include #include #include namespace torch::distributed::rpc { using DeviceMap = std::unordered_map; // Default RPC timeout constexpr float kDefaultRpcTimeoutSeconds = 60; // Unset RPC timeout. This is the value agent::send() will have if user does not // pass in a specific timeout, and indicates that we must use the default // timeout for RPCs. constexpr float kUnsetRpcTimeout = -1; constexpr auto kDefaultInitMethod = "env://"; constexpr float kSecToMsConversion = 1000; constexpr auto kRpcTimeoutErrorStr = "RPC ran for more than set timeout ({} ms) and will now be marked with an error"; using steady_clock_time_point = std::chrono::time_point; // Input is qualified name string, output is JIT StrongTypePtr // Same as jit::TypeResolver, did not import jit::TypeResolver to here // because it could introduce cyclic dependencies. using TypeResolver = std::function; struct TORCH_API RpcBackendOptions { RpcBackendOptions() : RpcBackendOptions(kDefaultRpcTimeoutSeconds, kDefaultInitMethod) {} RpcBackendOptions(float rpcTimeoutSeconds, std::string initMethod) : rpcTimeoutSeconds(rpcTimeoutSeconds), initMethod(std::move(initMethod)) { TORCH_CHECK(rpcTimeoutSeconds >= 0, "RPC Timeout must be non-negative"); } float rpcTimeoutSeconds; std::string initMethod; }; // A globally unique ID to identify an RpcAgent struct TORCH_API WorkerInfo : torch::CustomClassHolder { WorkerInfo(std::string name, int64_t id); WorkerInfo(std::string name, worker_id_t id); bool operator==(const WorkerInfo& rhs) { return (id_ == rhs.id_) && (name_ == rhs.name_); } static constexpr size_t MAX_NAME_LEN = 128; const std::string name_; const worker_id_t id_; }; struct TORCH_API RegisterWorkerInfoOnce { RegisterWorkerInfoOnce(); }; TORCH_API std::ostream& operator<<( std::ostream& os, const WorkerInfo& workerInfo); // Struct for options to configure the RPC Retry protocol. struct TORCH_API RpcRetryOptions { // Using a default constructor like all other Options structs in the RPC // codebase. TORCH_CHECKs for input validation are done in the // sendWithRetries function. RpcRetryOptions() = default; // Maximum number of times we will retry the RPC int maxRetries{5}; // Initial duration between consecutive RPC send attempts std::chrono::milliseconds rpcRetryDuration{std::chrono::milliseconds(1000)}; // Constant for exponential backoff used while calculating future wait // durations float retryBackoff{1.5}; }; // Struct that stores all the metadata needed to retry a given RPC. struct TORCH_API RpcRetryInfo { RpcRetryInfo( const WorkerInfo& to, c10::intrusive_ptr message, c10::intrusive_ptr originalFuture, int retryCount, RpcRetryOptions options) : to_(to), message_(std::move(message)), originalFuture_(std::move(originalFuture)), retryCount_(retryCount), options_(options) {} const WorkerInfo& to_; c10::intrusive_ptr message_; // Future that is returned to the caller of sendWithRetries(). c10::intrusive_ptr originalFuture_; // Number of send attempts completed so far. int retryCount_; RpcRetryOptions options_; }; // ``RpcAgent`` is the base class for sending and receiving RPC messages. It // provides a unified ``send`` API for both request and response messages, and // will invoke the given ``RequestCallback`` to process received requests. It // should immediately become ready to serve request and accept response after // construction. class TORCH_API RpcAgent { public: // `WorkerInfo` is the globally unique identifier for this RpcAgent instance. // It contains a ``name_`` field and an ``id_`` field. ``name_`` is the // globally unique name for this ``RpcAgent``. It is up to the ``RpcAgent`` // implementation to determine how to resolve names. ``id_`` is the globally // unique ID for this ``RpcAgent``. This should be determined by the // ``RpcAgent`` implementation. // The ``RequestCallback`` will be invoked to handle received requests. This // ``RpcAgent`` base class makes no assumption on the thread-safeness of the // ``RequestCallback``. ``RpcAgent`` implementations need to make sure that // its threading model conform to ``RequestCallback``'s requirement. // NB: RpcAgent implementations should not start serving requests until // ``start()`` is called, as there could be other contexts that have not been // initialized yet at this time. RpcAgent( WorkerInfo id, std::unique_ptr cb, std::chrono::milliseconds rpcTimeout); virtual ~RpcAgent(); // Send a message to the ``RpcAgent`` of id ``to`` and returns a // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it // cannot block until it receives the response. // // If ``message.isRequest()`` is true, the ``JitFuture`` will be // completed when the response arrives. For other message types, the Future // should be ignored by the caller. virtual c10::intrusive_ptr send( const WorkerInfo& to, c10::intrusive_ptr message, const float rpcTimeoutSeconds = kUnsetRpcTimeout, const DeviceMap& deviceMap = {}) = 0; // Retries sending the message up to maxRetries times until an ACK is // received. The duration between consecutive sends is increased over // time using an exponential backoff algorithm. // // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a // ``JitFuture`` ptr, just like send(). Caller can specify the maximum // number of retries for this RPC (default is 5), initial duration between // sends (default is 1000ms), and backoff constant (default is 1.5) by // passing in the RpcRetryOptions struct. This API might end up // executing a method twice on the remote end (it does not guarantee // exactly-once semantics). Therefore, the user must ensure their requests // are idempotent. c10::intrusive_ptr sendWithRetries( const WorkerInfo& to, c10::intrusive_ptr message, RpcRetryOptions retryOptions = RpcRetryOptions()); // Return a reference to the ``WorkerInfo`` of this RpcAgent. // NB: not using ``std::optional`` here because we might // need to create a separate RPC API lib and avoid forcing all ``RpcAgent`` // implementations to depend on libtorch. const WorkerInfo& getWorkerInfo() const; // Return a reference to the ``WorkerInfo`` of the given ``workerName``. virtual const WorkerInfo& getWorkerInfo( const std::string& workerName) const = 0; virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0; virtual std::vector getWorkerInfos() const = 0; // Retrieve the timeout for all RPCs. inline std::chrono::milliseconds getRpcTimeout() const { return rpcTimeout_.load(); } // Set the timeout for all RPCs inline void setRpcTimeout(const std::chrono::milliseconds& rpcTimeout) { rpcTimeout_.store(rpcTimeout); } // Call sync and join all internal threads. This method should be called // before every RPC process exits. virtual void join(bool shutdown = false, float timeout = 0) = 0; // Synchronize the this process with other ``RpcAgent`` processes. Block until // all ``RpcAgent``s reach this method and send all pending messages. virtual void sync() = 0; // Sets up backend-agnostic state for accepting requests. Currently, this // entails setting rpcAgentRunning_ to true, creating the retry thread, and // calling the backend's startImpl. void start(); // Derived classes must override this function to start accepting requests. // This is used to initialize any backend-specific state. Users must call // start, not startImpl, to initialize the RPC Agent. virtual void startImpl() = 0; // Stop accepting requests and shutdown the RPC framework as soon as possible // by terminating all RPC threads. void shutdown(); // Derived classes must override this function to start accepting requests. // THis is used to clean up any backend-specific state. Users must call // shutdown, not shutdownImpl, to shutdown the RPC Agent. virtual void shutdownImpl() = 0; // Check if current RPC agent is set. static bool isCurrentRpcAgentSet(); // Retrieve the valid current RPC agent. static std::shared_ptr getCurrentRpcAgent(); // Set the current RPC agent. static void setCurrentRpcAgent(std::shared_ptr rpcAgent); // Retrieve metrics as KV map virtual std::unordered_map getMetrics() = 0; // Retrieve debug info in addition to metrics as KV map virtual std::unordered_map getDebugInfo(); // Flag to control whether GIL wait times // should be profiled or not. void enableGILProfiling(bool flag); // Retrieve wheher we should profile GIL wait times or not. bool isGILProfilingEnabled(); // Set type resolver that will be passed to JIT pickler to resolver type Ptr // based on type str. void setTypeResolver(std::shared_ptr typeResolver); // Get the type resolver std::shared_ptr getTypeResolver(); // Retrieves the device map for the provided destination worker. virtual DeviceMap getDeviceMap(const WorkerInfo& dst) const; // Retrieve the (non-CPU) devices that are supported by the agent. virtual const std::vector& getDevices() const; protected: // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const WorkerInfo workerInfo_; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) const std::unique_ptr cb_; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::atomic rpcTimeout_; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::atomic profilingEnabled_; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::shared_ptr typeResolver_; // Atomic boolean indicating whether this agent is running. It controls // whether several background threads should be running. It is set in // RpcAgent::start() and unset in the derived class shutdown(). // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) std::atomic rpcAgentRunning_; private: static std::shared_ptr currentRpcAgent_; // Add GIL wait time data point to metrics virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0; friend class PythonRpcHandler; // Map that stores metadata for RPC's that may need to be re-tried as well as // the timepoint at which we should re-try them. std::map< steady_clock_time_point, std::unordered_set>> rpcRetryMap_; // Thread that checks for retryable RPC's in the rpcRetryMap_ and sleeps until // the next unACKed RPC's timeout has expired. std::thread rpcRetryThread_; // Function that rpcRetryThread_ calls in a loop as long as RpcAgent is // running. void retryExpiredRpcs(); // This is the callback attached to futures corresponding to send retries. // This handles 3 cases: 1). send was completed, 2). send failed with an // error and we've done maxRetries failed send attempts, and 3). send // failed with an error and we have more retries to go. In case 1, we mark // the original future as complete. In case 2, we mark the future with an // error and do not retry again. In case 3, we move the RpcRetryInfo struct // to another time point in the map to schedule the RPC for a future send. void rpcRetryCallback( JitFuture& message, steady_clock_time_point newTime, std::shared_ptr earliestRpc); // Function that uses the exponential backoff algorithm to compute the next // time point to retry a given RPC. inline steady_clock_time_point computeNewRpcRetryTime( RpcRetryOptions& options, int retryCount) { // The exponential backoff algorithm being used here is: // newTime = timeNow + (retryDuration * (backoffConstant ^ retryCount)). std::chrono::milliseconds timedelta = std::chrono::duration_cast( options.rpcRetryDuration * pow(options.retryBackoff, retryCount)); return std::chrono::time_point_cast( std::chrono::steady_clock::now() + timedelta); } // Condition Variable to signal when the rpcRetryMap_ has been populated. std::condition_variable rpcRetryMapCV_; // Mutex to protect RpcRetryMap_. std::mutex rpcRetryMutex_; }; } // namespace torch::distributed::rpc namespace std { template <> struct hash { std::size_t operator()( const torch::distributed::rpc::WorkerInfo& worker_info) const noexcept { return worker_info.id_; } }; } // namespace std