
Lines Matching full:autograd

1 .. _func-autograd-function:
3 Extending torch.func with autograd.Function
6 .. currentmodule:: torch.autograd
8 So you'd like to use :class:`torch.autograd.Function` with the :mod:`torch.func`
14 have it work with function transforms. That is, the :class:`torch.autograd.Function`'s
19 PyTorch combines both of these concepts into :class:`torch.autograd.Function`.
24 This guide assumes you are familiar with :ref:`extending-autograd`,
25 which explains how to use :class:`torch.autograd.Function`.
27 :class:`torch.autograd.Function` can either have a :meth:`~Function.forward` that accepts a ctx obj…
51 the :class:`torch.autograd.Function` needs a :meth:`~Function.backward` staticmethod.
52 - to support :func:`torch.vmap`, the :class:`torch.autograd.Function` needs a :meth:`~Function.vmap…
53 - to support :func:`torch.func.jvp`, the :class:`torch.autograd.Function` needs a :meth:`~Function.…
58 In order for the :class:`torch.autograd.Function` to be arbitrarily composable with function
61 operators or call other :class:`torch.autograd.Function` (that may call into C++/CUDA/etc).
65 Example 1: autograd.Function calls into another system
68 A common case is a :class:`torch.autograd.Function` with both forward() and backward() calling
79 class NumpySort(torch.autograd.Function):
120 # For the autograd.Function to be arbitrarily composable with function
123 # only consist of PyTorch operations or autograd.Function.
129 # autograd.Function, NumpyTake.
133 class NumpyTake(torch.autograd.Function):
170 Example 2: autograd.Function specifies custom gradient rules
173 Another common case is an :class:`torch.autograd.Function` that is implemented with PyTorch
183 Here's an example of an :class:`torch.autograd.Function` for the function ``y = x ** 3`` where we
189 class MyCube(torch.autograd.Function):
194 # pass computes dx = 3 * x ** 2. In this autograd.Function, we've done
208 # In order for the autograd.Function to work with higher-order
231 Please read these limitations of :class:`torch.autograd.Function` with torch.func transforms
237 :class:`torch.autograd.Function`. The way to be completely safe is to ensure that the only
238 Tensors being used inside any method of the :class:`torch.autograd.Function` must be directly
240 the :class:`torch.autograd.Function`.
242 :class:`torch.autograd.Function` does not handle Tensors in pytrees (arbitrary nested
244 those Tensors to be tracked by autograd, they must be passed directly as
245 an argument to :class:`torch.autograd.Function`. This is in contrast to
248 Please only use :meth:`~torch.autograd.function.FunctionCtx.save_for_backward` or
249 :meth:`~torch.autograd.function.FunctionCtx.save_for_forward` to save Tensors.
257 To use an :class:`torch.autograd.Function` with :func:`torch.vmap`, you must either:
259 …th:`~Function.vmap` staticmethod that tells us the behavior of the :class:`torch.autograd.Function`
266 If your :class:`torch.autograd.Function` fulfills the following additional constraints, then we
276 - The :class:`torch.autograd.Function`'s :meth:`~Function.forward`, :meth:`~Function.backward` (if …
283 class MyCube(torch.autograd.Function):
318 If your :class:`torch.autograd.Function` calls into another system (like NumPy, C++, CUDA, triton),
323 to add a :meth:`~Function.vmap` staticmethod to all of your :class:`torch.autograd.Function`:
329 We do recommend ensuring all of your :class:`torch.autograd.Function` have support for
331 :class:`torch.autograd.Function` to work with all combinations of :func:`torch.func` transforms.
365 class NumpySort(torch.autograd.Function):
428 class NumpyTake(torch.autograd.Function):
487 entire :class:`~torch.autograd.Function`. That is, (pseudocode) ``grad(vmap(MyFunc))``
490 If your autograd.Function has any custom behavior in the backward pass, please
496 :class:`~torch.autograd.Function` that PyTorch is able to generate a vmap
503 To support forward-mode AD, a :class:`torch.autograd.Function` must have a :meth:`~Function.jvp` st…
504 Please see :ref:`forward-ad-autograd-function` for details.