torch.func Whirlwind Tour
=========================

What is torch.func?
-------------------

.. currentmodule:: torch.func

torch.func, previously known as functorch, is a library for
`JAX <https://github.com/google/jax>`_-like composable function transforms in
PyTorch.

13- A "function transform" is a higher-order function that accepts a numerical
14  function and returns a new function that computes a different quantity.
15- torch.func has auto-differentiation transforms (``grad(f)`` returns a function
16  that computes the gradient of ``f``), a vectorization/batching transform
17  (``vmap(f)`` returns a function that computes ``f`` over batches of inputs),
18  and others.
- These function transforms can compose with each other arbitrarily. For
  example, composing ``vmap(grad(f))`` computes a quantity called
  per-sample-gradients that stock PyTorch cannot efficiently compute today;
  a quick sketch follows this list.

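As a quick illustration of composition (a minimal sketch; the per-sample-gradient
recipe is covered in more detail below), transformed functions are ordinary
functions and can simply be passed to other transforms:

.. code-block:: python

    import torch
    from torch.func import grad, vmap

    # grad(f) is itself a plain function, so it can be handed to vmap.
    f = lambda x: (x ** 2).sum()
    x = torch.randn(8, 3)
    per_sample_grads = vmap(grad(f))(x)  # gradient of f for each row of x
    assert torch.allclose(per_sample_grads, 2 * x)
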
Why composable function transforms?
-----------------------------------
There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, :func:`vjp`, and :func:`jvp` transforms
allows us to express the above without designing a separate subsystem for each.

What are the transforms?
------------------------

:func:`grad` (gradient computation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``grad(func)`` is our gradient computation transform. It returns a new function
that computes the gradients of ``func``. It assumes ``func`` returns a
single-element Tensor, and by default it computes the gradients of the output
of ``func`` with respect to the first input.

.. code-block:: python

    import torch
    from torch.func import grad
    x = torch.randn([])
    cos_x = grad(lambda x: torch.sin(x))(x)
    assert torch.allclose(cos_x, x.cos())

    # Second-order gradients
    neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
    assert torch.allclose(neg_sin_x, -x.sin())

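``grad`` also takes an ``argnums`` argument to differentiate with respect to
inputs other than the first one. A short sketch, continuing the imports above:

.. code-block:: python

    def f(x, y):
        return (x * y).sum()

    x, y = torch.randn(3), torch.randn(3)
    # Differentiate with respect to the second argument instead of the first.
    grad_y = grad(f, argnums=1)(x, y)
    assert torch.allclose(grad_y, x)

    # argnums can also be a tuple to get gradients for several inputs at once.
    grad_x, grad_y = grad(f, argnums=(0, 1))(x, y)
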
:func:`vmap` (auto-vectorization)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Note: :func:`vmap` imposes restrictions on the code that it can be used on. For more
details, please see :ref:`ux-limitations`.

``vmap(func)(*inputs)`` is a transform that adds a dimension to all Tensor
operations in ``func``. ``vmap(func)`` returns a new function that maps ``func``
over some dimension (default: 0) of each Tensor in ``inputs``.

vmap is useful for hiding batch dimensions: one can write a function ``func`` that
runs on examples and then lift it to a function that can take batches of
examples with ``vmap(func)``, leading to a simpler modeling experience:

.. code-block:: python

    import torch
    from torch.func import vmap
    batch_size, feature_size = 3, 5
    weights = torch.randn(feature_size, requires_grad=True)

    def model(feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    examples = torch.randn(batch_size, feature_size)
    result = vmap(model)(examples)

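To confirm that the lifted function maps ``model`` over dimension 0, the result
can be compared against an explicit Python loop (a small sketch continuing the
example above):

.. code-block:: python

    # Continuing the example above: vmap(model) should match a per-example loop.
    looped = torch.stack([model(example) for example in examples])
    assert torch.allclose(result, looped)
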
When composed with :func:`grad`, :func:`vmap` can be used to compute per-sample-gradients:

.. code-block:: python

    from torch.func import grad, vmap
    batch_size, feature_size = 3, 5

    def model(weights, feature_vec):
        # Very simple linear model with activation
        assert feature_vec.dim() == 1
        return feature_vec.dot(weights).relu()

    def compute_loss(weights, example, target):
        y = model(weights, example)
        return ((y - target) ** 2).mean()  # MSELoss

    weights = torch.randn(feature_size, requires_grad=True)
    examples = torch.randn(batch_size, feature_size)
    targets = torch.randn(batch_size)
    inputs = (weights, examples, targets)
    grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

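Here ``in_dims=(None, 0, 0)`` tells :func:`vmap` not to map over ``weights``
while mapping over dimension 0 of ``examples`` and ``targets``. As a sanity
check (a small sketch continuing the example above), the result should match a
per-example loop over ``grad(compute_loss)``:

.. code-block:: python

    # Continuing the example above: the vmap'ed per-sample gradients should
    # match an explicit loop over grad(compute_loss).
    expected = torch.stack([
        grad(compute_loss)(weights, examples[i], targets[i])
        for i in range(batch_size)
    ])
    assert torch.allclose(grad_weight_per_example, expected)
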
:func:`vjp` (vector-Jacobian product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`vjp` transform applies ``func`` to ``inputs`` and returns a new function
that computes the vector-Jacobian product (vjp) given some ``cotangents`` Tensors.

.. code-block:: python

    import torch
    from torch.func import vjp

    inputs = torch.randn(3)
    func = torch.sin
    cotangents = (torch.randn(3),)

    outputs, vjp_fn = vjp(func, inputs)
    vjps = vjp_fn(*cotangents)

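Because the Jacobian of ``torch.sin`` is diagonal, the vjp here is just an
elementwise product with ``cos(inputs)``; a small check continuing the example
above:

.. code-block:: python

    # Continuing the example above: for sin, the Jacobian is diag(cos(inputs)),
    # so the vector-Jacobian product is an elementwise product.
    assert torch.allclose(vjps[0], cotangents[0] * inputs.cos())
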
:func:`jvp` (Jacobian-vector product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jvp` transform computes Jacobian-vector products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order function:
instead of returning a new function, it returns the outputs of ``func(inputs)`` as
well as the jvps.

.. code-block:: python

    import torch
    from torch.func import jvp
    x = torch.randn(5)
    y = torch.randn(5)
    f = lambda x, y: (x * y)
    _, out_tangent = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
    assert torch.allclose(out_tangent, x + y)

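The tangents act as a direction: for the elementwise product above, the jvp
along tangents ``(u, v)`` is ``u * y + x * v``. A small check continuing the
example:

.. code-block:: python

    # Continuing the example above with arbitrary tangents u and v:
    # for f(x, y) = x * y, the jvp along (u, v) is u * y + x * v.
    u, v = torch.randn(5), torch.randn(5)
    _, out_tangent = jvp(f, (x, y), (u, v))
    assert torch.allclose(out_tangent, u * y + x * v)
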
:func:`jacrev`, :func:`jacfwd`, and :func:`hessian`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jacrev` transform returns a new function that takes in ``x`` and returns
the Jacobian of the function with respect to ``x`` using reverse-mode AD.

.. code-block:: python

    import torch
    from torch.func import jacrev
    x = torch.randn(5)
    jacobian = jacrev(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

:func:`jacrev` can be composed with :func:`vmap` to produce batched Jacobians:

.. code-block:: python

    x = torch.randn(64, 5)
    jacobian = vmap(jacrev(torch.sin))(x)
    assert jacobian.shape == (64, 5, 5)

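Each batch entry is the Jacobian of ``torch.sin`` at the corresponding row of
``x``, i.e. ``diag(cos(x[i]))``; a small check continuing the example above:

.. code-block:: python

    # Continuing the example above: each slice is diag(cos(x[i])).
    expected = torch.stack([torch.diag(torch.cos(row)) for row in x])
    assert torch.allclose(jacobian, expected)
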
:func:`jacfwd` is a drop-in replacement for :func:`jacrev` that computes Jacobians using
forward-mode AD:

.. code-block:: python

    import torch
    from torch.func import jacfwd
    x = torch.randn(5)
    jacobian = jacfwd(torch.sin)(x)
    expected = torch.diag(torch.cos(x))
    assert torch.allclose(jacobian, expected)

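As a rule of thumb (a heuristic, not a guarantee from this document):
forward-mode AD builds the Jacobian one column at a time while reverse-mode
builds it one row at a time, so :func:`jacfwd` tends to be preferable for
functions with more outputs than inputs and :func:`jacrev` for functions with
more inputs than outputs. A minimal sketch with two hypothetical functions:

.. code-block:: python

    import torch
    from torch.func import jacfwd, jacrev

    # Hypothetical examples: a "wide" function (many inputs, one output)
    # and a "tall" function (one input, many outputs).
    wide = lambda x: x.sum()         # R^2048 -> R: favors jacrev
    tall = lambda x: x.repeat(2048)  # R^1 -> R^2048: favors jacfwd

    assert jacrev(wide)(torch.randn(2048)).shape == (2048,)
    assert jacfwd(tall)(torch.randn(1)).shape == (2048, 1)
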
Composing :func:`jacrev` with itself or with :func:`jacfwd` can produce Hessians:

.. code-block:: python

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hessian0 = jacrev(jacrev(f))(x)
    hessian1 = jacfwd(jacrev(f))(x)

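Both compositions compute the same matrix, and for ``f(x) = sin(x).sum()`` the
Hessian is ``diag(-sin(x))``, which makes for an easy check (continuing the
example above):

.. code-block:: python

    # Continuing the example above: both compositions agree and match the
    # analytic Hessian of sin(x).sum(), which is diag(-sin(x)).
    assert torch.allclose(hessian0, hessian1)
    assert torch.allclose(hessian0, torch.diag(-x.sin()))
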
:func:`hessian` is a convenience function that combines :func:`jacfwd` and :func:`jacrev`:

.. code-block:: python

    import torch
    from torch.func import hessian

    def f(x):
        return x.sin().sum()

    x = torch.randn(5)
    hess = hessian(f)(x)

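The result matches the explicit compositions from the previous snippet (a small
check continuing the example; note that ``x`` was re-drawn above):

.. code-block:: python

    # Continuing the example above: hessian(f) agrees with the explicit
    # jacfwd/jacrev composition and with the analytic Hessian diag(-sin(x)).
    from torch.func import jacfwd, jacrev
    assert torch.allclose(hess, jacfwd(jacrev(f))(x))
    assert torch.allclose(hess, torch.diag(-x.sin()))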