#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/variable.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::variable_list;

/// C++ API of Distributed Autograd that kicks off the distributed backward
/// pass using the provided roots. This currently implements the
/// :ref:`fast-mode-algorithm` which assumes all RPC messages sent in the same
/// distributed autograd context across workers would be part of the autograd
/// graph during the backward pass.
///
/// We use the provided roots to discover the autograd graph and compute
/// appropriate dependencies. This method blocks until the entire autograd
/// computation is done.
///
/// This function accumulates gradients in the leaves - you might need to zero
/// them before calling it.
///
/// \param context_id The autograd context id for which we should retrieve the
///                   gradients.
/// \param roots Tensors which represent the roots of the autograd computation.
///              All the tensors should be scalars.
/// \param retain_graph If `false`, the graph used to compute the grad will be
///                     freed. Note that in nearly all cases setting this
///                     option to `true` is not needed and often can be worked
///                     around in a much more efficient way. Usually, you need
///                     to set this to `true` to run backward multiple times.
TORCH_API void backward(
    int64_t context_id,
    const variable_list& roots,
    bool retain_graph = false);

} // namespace autograd
} // namespace distributed
} // namespace torch
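
// Usage sketch (illustration only, not part of this header): one plausible way
// to drive this entry point from C++. The context-management calls
// (DistAutogradContainer::getInstance/newContext/contextId/releaseContext) are
// assumed to come from torch/csrc/distributed/autograd/context/container.h,
// and an RPC agent is assumed to be initialized before this runs.
//
//   #include <torch/csrc/distributed/autograd/autograd.h>
//   #include <torch/csrc/distributed/autograd/context/container.h>
//   #include <torch/torch.h>
//
//   void run_distributed_backward() {
//     using torch::distributed::autograd::DistAutogradContainer;
//
//     // Open a fresh distributed autograd context for this pass.
//     auto& container = DistAutogradContainer::getInstance();
//     auto context = container.newContext();
//
//     // Forward pass: in a real job the loss would be produced via RPC calls
//     // recorded in this context; a local scalar stands in for it here.
//     auto x = torch::randn({2, 2}, torch::requires_grad());
//     auto loss = x.sum();  // scalar root, as required above
//
//     // Kick off the distributed backward pass; blocks until it completes.
//     torch::distributed::autograd::backward(
//         context->contextId(), {loss}, /*retain_graph=*/false);
//
//     // Release the context once its gradients have been consumed.
//     container.releaseContext(context->contextId());
//   }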