#pragma once #include #include #include namespace torch { namespace distributed { namespace autograd { // Singleton class per worker which is responsible for storing the distributed // autograd context for each autograd pass and also cleans up data for an // autograd pass once its done. // // Each autograd pass is assigned a unique autograd_context_id and all data for // that pass (DistAutogradContext) is stored in this container indexed by the // autograd_context_id. The autograd_context_id itself is a 64 bit globally // unique id. The first 16 bits is the worker_id and the next 48 bits is an // auto-incrementing id for each worker. // // This container is also responsible for maintaining a globally unique message // id, which is used to associate send/recv autograd function pairs. The format // is similar to the autograd_context_id where we have a 64 bit integer with // first 16 bits being the worker id and next 48 bits are auto-incrementing. class TORCH_API DistAutogradContainer { public: explicit DistAutogradContainer(uint32_t num_shards); // One time initialization of the container. static DistAutogradContainer& init(int64_t worker_id); // Retrieve the singleton instance of the container, ensures we have // initialized the container. static DistAutogradContainer& getInstance(); // Create a new context for a distributed autograd pass. const ContextPtr newContext(); // Clean up resources for a given context_id once the autograd pass is done. // Sends RPC to other workers this worker knows about, telling them to clean // up their context as well. Throws an exception if the context_id does not // exist. void releaseContext(int64_t context_id); // Releases an autograd context if it is present on this node. Also sends RPC // to other workers this worker knows about, telling them to clean up their // context. Does nothing if it is not present. void releaseContextIfPresent(int64_t context_id); // Checks if the passed in context_id is valid. void isValidContext(int64_t context_id); // Retrieve the autograd context for a given context_id. ContextPtr retrieveContext(int64_t context_id); // Retrieves the currently active autograd context for the current thread. ContextPtr currentContext(); // Checks whether or not the current thread has a valid autograd context. bool hasValidContext() const; // Generate a new autograd_message_id for send/recv autograd functions. int64_t newAutogradMessageId(); // Creates a new autograd context with the provided context_id. If a context // already exists with the provided context_id, we just return it. // This does not set the current context for the current thread. ContextPtr getOrCreateContext(int64_t context_id); // Retrieves the maximum possible autograd_context_id/autograd_message_id that // can be generated by this worker. int64_t getMaxId(); // Retrieves the worker ID for this node rpc::worker_id_t getWorkerId() const; // Can set current context id if there is no valid context yet static void setCurrentContextId(int64_t contextId); // Forcibly sets the thread local current context id. Should only be used in // cases where you know what you're doing and need to override the thread // local. Otherwise, use setCurrentContextId instead. static void forceCurrentContextId(int64_t contextId); // Clear current context id void clearCurrentContext(); // Returns the number of autograd contexts in the container. size_t numAutogradContexts() const; // Returns the current thread local context id for this thread. static int64_t currentContextId(); DistAutogradContainer(const DistAutogradContainer&) = delete; DistAutogradContainer& operator=(const DistAutogradContainer&) = delete; DistAutogradContainer(DistAutogradContainer&&) = delete; DistAutogradContainer& operator=(DistAutogradContainer&&) = delete; private: // Number of shards for the map storing autograd contexts. We'd like this // to be a power of 2 and we don't expect a value much higher than the // number of cores would provide much benefit. static constexpr uint32_t kNumDefaultShards = 128; // Use cache line size for alignment. static constexpr int kCacheLineSize = 64; // Structure holding one shard of the sharded autograd context map with its // associated lock. Align to cache line size to avoid contention between // adjacent entries. struct alignas(kCacheLineSize) ContextsShard { // Lock for this shard. mutable std::mutex lock; // Map storing autograd contexts for this shard. std::unordered_map contexts; }; DistAutogradContainer() = delete; ~DistAutogradContainer() = default; static DistAutogradContainer& getInstanceInternal(); // Retrieve the shard for given context_id. ContextsShard& getShard(int64_t context_id); // Sends an RPC to the workers that have a context corresponding to passed in // context_id. This function should be called with the lock. void sendReleaseContextRpc( const std::unordered_set& workerIds, int64_t context_id); // Erase context_id from the autograd context map, and reset the thread local // current context id if it corresponds to the passed in context id. This // function should be called with the lock. void eraseContextIdAndReset(ContextsShard& shard, int64_t context_id); // Compute the number of shards for the autograd_contexts_ map. static uint32_t computeNumShards(); // Auto incrementing context id used to identify unique autograd passes. // Initialized with the first 16 bits being the worker_id. std::atomic next_context_id_; // Unique id to identify a worker in the distributed setting. int16_t worker_id_; // Whether or not the container has been initialized appropriately. bool initialized_; // Sharded autograd context map. std::vector autograd_contexts_; // Number of shards for the sharded autograd_contexts_ map. uint32_t num_shards_; // Autograd message id to identify unique send/recv autograd function pairs. std::atomic next_autograd_message_id_; // Maximum allowed value for autograd_context_id or autograd_message_id. int64_t max_id_; }; } // namespace autograd } // namespace distributed } // namespace torch