#pragma once

#include <cstdint>
#include <functional>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include <ATen/core/Dict.h>
#include <c10/core/Event.h>
#include <c10/core/impl/VirtualGuardImpl.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

class RecvRpcBackward;

// DistAutogradContext which stores information for a single distributed
// autograd pass on a worker.
class TORCH_API DistAutogradContext {
 public:
  using GradCallback = std::function<bool(torch::Tensor&)>;

  explicit DistAutogradContext(int64_t contextId);

  // Retrieves the autograd context id for this context.
  int64_t contextId() const;

  // Records a 'send' autograd function for this context with the provided
  // message id.
  void addSendFunction(
      const std::shared_ptr<SendRpcBackward>& func,
      int64_t autograd_message_id);

  // Records a 'recv' autograd function for this context with the provided
  // message id.
  void addRecvFunction(
      std::shared_ptr<RecvRpcBackward>& func,
      int64_t autograd_message_id);

  // Given an autograd_message_id, retrieve the appropriate send function.
  std::shared_ptr<SendRpcBackward> retrieveSendFunction(
      int64_t autograd_message_id);

  // Return all send functions for this context.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>> sendFunctions()
      const;

  // Return all recv functions for this context.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>> recvFunctions()
      const;

  // Adds a future that records an outstanding RPC.
  void addOutstandingRpc(const c10::intrusive_ptr<rpc::JitFuture>& jitFuture);

  // Returns all gradients.
  const c10::Dict<torch::Tensor, torch::Tensor> getGradients() const;

  // This function gives a mutable grad reference to the callback.
  // If the callback returns true, it means the grad in the context
  // needs to be updated.
  void runGradCallbackForVariable(
      const torch::autograd::Variable& variable,
      GradCallback&& cb);

  DistAutogradContext(const DistAutogradContext&) = delete;
  DistAutogradContext& operator=(const DistAutogradContext&) = delete;
  DistAutogradContext(DistAutogradContext&&) = delete;
  DistAutogradContext& operator=(DistAutogradContext&&) = delete;

  // Records the workerID of a node that we sent an RPC to.
  // workerIDs are added here when we attach a send function to this autograd
  // context.
  void addKnownWorkerId(const rpc::worker_id_t workerId);

  // Retrieves a set containing the known workerIds for this context.
  // These are the different workers that this context has sent RPCs to.
  std::unordered_set<rpc::worker_id_t> getKnownWorkerIds() const;

 private:
  friend class BackwardPassCleanupGuard;
  friend class DistEngine;
  friend class RecvRpcBackward;
  friend class DistAccumulateGradCaptureHook;

  // Record that we would like to accumulate the provided gradient on the given
  // variable.
  void accumulateGrad(
      const torch::autograd::Variable& variable,
      const torch::Tensor& grad,
      size_t num_expected_refs);

  // Retrieve the GraphTask.
  std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTask();

  // Set the appropriate graph task for the backward pass. Can be called only
  // once.
  void setGraphTask(std::shared_ptr<torch::autograd::GraphTask> graphTask);

  // Resets the graph task to ensure we can run another distributed backward
  // pass for the same autograd context.
  void resetGraphTask();

  // Waits for all outstanding RPCs for this context to finish and clears all
  // outstanding RPCs held in this context. This should be called only once.
  c10::intrusive_ptr<c10::ivalue::Future> clearAndWaitForOutstandingRpcsAsync();

  void clearOutstandingRpcs();

  // Record an event to mark the completion of gradient computation. These
  // events will later help to properly synchronize gradient consumption
  // in getGradients(). We need these events because backward and
  // optimizer.step are separate RPC calls, and will occur on different CUDA
  // streams. Without synchronization, it is possible that gradients are
  // consumed before they are ready.
  void recordGradEvent(c10::Device device);

  const int64_t contextId_;

  // Set containing known worker IDs, used in cleaning up autograd context.
  // Whenever a sendRpcBackward is attached to the autograd graph for this
  // context, the destination is added here.
  std::unordered_set<rpc::worker_id_t> knownWorkerIds_;

  // Map from autograd_message_id to appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;

  // Gradients accumulated in this context so far. The key is the variable on
  // which the gradient needs to be accumulated and the value is the gradient
  // that needs to be accumulated on that variable.
  c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;

  // See comments for recordGradEvent(c10::Device device).
  std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
  const c10::impl::VirtualGuardImpl impl_;

  // The autograd GraphTask for the backward pass on this node for this
  // context.
  std::shared_ptr<torch::autograd::GraphTask> graphTask_;

  // List of futures for RPCs initiated by this node to propagate gradients to
  // other nodes. The distributed autograd engine on this node can return
  // successfully only if all these futures are done and are successful.
  std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;

  // Lock to protect concurrent modification of the context.
  mutable std::mutex lock_;
};
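
// Illustrative sketch (not part of this header's API): one way a caller
// holding a DistAutogradContext could use runGradCallbackForVariable() to
// fold an incoming gradient into the gradient already accumulated for a
// variable. The names `ctx`, `variable` and `incomingGrad` are assumed
// placeholders, and the sketch assumes the context already holds a gradient
// for `variable`.
//
//   void addToAccumulatedGrad(
//       DistAutogradContext& ctx,
//       const torch::autograd::Variable& variable,
//       const torch::Tensor& incomingGrad) {
//     ctx.runGradCallbackForVariable(
//         variable, [&incomingGrad](torch::Tensor& grad) {
//           grad = grad + incomingGrad;
//           // Returning true asks the context to store the updated gradient.
//           return true;
//         });
//   }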

using ContextPtr = std::shared_ptr<DistAutogradContext>;

// This class stores a shared_ptr to a DistAutogradContext instance in a
// thread local variable. The instance is provided by the call site; the class
// itself does not track which context is current. It is just a utility class.
class TORCH_API ThreadLocalDistAutogradContext {
 public:
  // Store 'new_context' to the thread local variable maintained by this class.
  explicit ThreadLocalDistAutogradContext(ContextPtr&& new_context);
  ~ThreadLocalDistAutogradContext();

  // Retrieve the stored DistAutogradContext instance.
  static ContextPtr getContextPtr();

 private:
  ContextPtr prev_context_ptr_;
};

} // namespace autograd
} // namespace distributed
} // namespace torch
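
// Illustrative sketch (not part of this header's API):
// ThreadLocalDistAutogradContext is intended to be used as a scoped RAII
// guard so that code running later on the same thread can discover the
// installed context via getContextPtr(). The name `someContext` below is an
// assumed placeholder for a ContextPtr obtained elsewhere (e.g. from the
// container that owns the contexts).
//
//   using torch::distributed::autograd::ContextPtr;
//   using torch::distributed::autograd::ThreadLocalDistAutogradContext;
//
//   void processWithContext(ContextPtr someContext) {
//     ThreadLocalDistAutogradContext guard(std::move(someContext));
//     // While `guard` is alive, getContextPtr() returns the context that was
//     // installed above; on destruction the previous context is restored.
//     ContextPtr current = ThreadLocalDistAutogradContext::getContextPtr();
//   }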