torch.func Whirlwind Tour
=========================

What is torch.func?
-------------------

.. currentmodule:: torch.func

torch.func, previously known as functorch, is a library for
`JAX <https://github.com/google/jax>`_-like composable function transforms in
PyTorch.

- A "function transform" is a higher-order function that accepts a numerical
  function and returns a new function that computes a different quantity.
- torch.func has auto-differentiation transforms (``grad(f)`` returns a function
  that computes the gradient of ``f``), a vectorization/batching transform
  (``vmap(f)`` returns a function that computes ``f`` over batches of inputs),
  and others.
- These function transforms can compose with each other arbitrarily. For
  example, composing ``vmap(grad(f))`` computes a quantity called
  per-sample-gradients that stock PyTorch cannot efficiently compute today.

Why composable function transforms?
-----------------------------------

Several use cases are tricky to express in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, :func:`vjp`, and :func:`jvp` transforms
allows us to express the above without designing a separate subsystem for each.

What are the transforms?
------------------------

:func:`grad` (gradient computation)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``grad(func)`` is our gradient computation transform. It returns a new function
that computes the gradients of ``func``. It assumes ``func`` returns a single-element
Tensor and by default it computes the gradients of the output of ``func`` with
respect to the first input.

.. code-block:: python

  import torch
  from torch.func import grad

  x = torch.randn([])
  cos_x = grad(lambda x: torch.sin(x))(x)
  assert torch.allclose(cos_x, x.cos())

  # Second-order gradients
  neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
  assert torch.allclose(neg_sin_x, -x.sin())

:func:`vmap` (auto-vectorization)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Note: :func:`vmap` imposes restrictions on the code that it can be used on. For more
details, please see :ref:`ux-limitations`.

``vmap(func)(*inputs)`` is a transform that adds a dimension to all Tensor
operations in ``func``. ``vmap(func)`` returns a new function that maps ``func``
over some dimension (default: 0) of each Tensor in ``inputs``.

:func:`vmap` is useful for hiding batch dimensions: one can write a function
``func`` that runs on single examples and then lift it to a function that takes
batches of examples with ``vmap(func)``, leading to a simpler modeling experience:

.. code-block:: python

  import torch
  from torch.func import vmap

  batch_size, feature_size = 3, 5
  weights = torch.randn(feature_size, requires_grad=True)

  def model(feature_vec):
      # Very simple linear model with activation
      assert feature_vec.dim() == 1
      return feature_vec.dot(weights).relu()

  examples = torch.randn(batch_size, feature_size)
  result = vmap(model)(examples)

When composed with :func:`grad`, :func:`vmap` can be used to compute per-sample-gradients:

.. code-block:: python

  import torch
  from torch.func import grad, vmap

  batch_size, feature_size = 3, 5

  def model(weights, feature_vec):
      # Very simple linear model with activation
      assert feature_vec.dim() == 1
      return feature_vec.dot(weights).relu()

  def compute_loss(weights, example, target):
      y = model(weights, example)
      return ((y - target) ** 2).mean()  # MSELoss

  weights = torch.randn(feature_size, requires_grad=True)
  examples = torch.randn(batch_size, feature_size)
  targets = torch.randn(batch_size)
  inputs = (weights, examples, targets)
  # in_dims=(None, 0, 0): share weights, map over dim 0 of examples and targets
  grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

:func:`vjp` (vector-Jacobian product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`vjp` transform applies ``func`` to ``inputs`` and returns a new function
that computes the vector-Jacobian product (vjp) given some ``cotangents`` Tensors.

.. code-block:: python

  import torch
  from torch.func import vjp

  inputs = torch.randn(3)
  func = torch.sin
  cotangents = (torch.randn(3),)

  outputs, vjp_fn = vjp(func, inputs)
  vjps = vjp_fn(*cotangents)

:func:`jvp` (Jacobian-vector product)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jvp` transform computes Jacobian-vector products and is also known as
"forward-mode AD". Unlike most other transforms, it is not a higher-order function;
instead, it returns the outputs of ``func(inputs)`` together with the jvps.

.. code-block:: python

  import torch
  from torch.func import jvp

  x = torch.randn(5)
  y = torch.randn(5)
  f = lambda x, y: x * y
  _, out_tangent = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
  assert torch.allclose(out_tangent, x + y)

:func:`jacrev`, :func:`jacfwd`, and :func:`hessian`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The :func:`jacrev` transform returns a new function that takes in ``x`` and returns
the Jacobian of the function with respect to ``x`` using reverse-mode AD.

.. code-block:: python

  import torch
  from torch.func import jacrev

  x = torch.randn(5)
  jacobian = jacrev(torch.sin)(x)
  expected = torch.diag(torch.cos(x))
  assert torch.allclose(jacobian, expected)

:func:`jacrev` can be composed with :func:`vmap` to produce batched Jacobians:

.. code-block:: python

  import torch
  from torch.func import jacrev, vmap

  x = torch.randn(64, 5)
  jacobian = vmap(jacrev(torch.sin))(x)
  assert jacobian.shape == (64, 5, 5)

:func:`jacfwd` is a drop-in replacement for :func:`jacrev` that computes Jacobians
using forward-mode AD:

.. code-block:: python

  import torch
  from torch.func import jacfwd

  x = torch.randn(5)
  jacobian = jacfwd(torch.sin)(x)
  expected = torch.diag(torch.cos(x))
  assert torch.allclose(jacobian, expected)

Composing :func:`jacrev` with itself or with :func:`jacfwd` can produce Hessians:

.. code-block:: python

  import torch
  from torch.func import jacfwd, jacrev

  def f(x):
      return x.sin().sum()

  x = torch.randn(5)
  hessian0 = jacrev(jacrev(f))(x)
  hessian1 = jacfwd(jacrev(f))(x)

:func:`hessian` is a convenience function that combines :func:`jacfwd` and
:func:`jacrev`:

.. code-block:: python

  import torch
  from torch.func import hessian

  def f(x):
      return x.sin().sum()

  x = torch.randn(5)
  hess = hessian(f)(x)
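
As a sanity check on the Hessian examples above, the following minimal sketch compares the three ways of building a Hessian against the closed-form answer. It assumes the same ``f(x) = x.sin().sum()`` used above, whose Hessian is a diagonal matrix with entries ``-sin(x_i)``. The variable names and the ``atol`` tolerance are illustrative choices, not part of the torch.func API.

```python
import torch
from torch.func import hessian, jacfwd, jacrev


def f(x):
    # Scalar-valued function with a known closed-form Hessian.
    return x.sin().sum()


x = torch.randn(5)

# Three ways of computing the same 5x5 Hessian.
hess = hessian(f)(x)                  # convenience wrapper
hess_fwd_rev = jacfwd(jacrev(f))(x)   # forward-over-reverse
hess_rev_rev = jacrev(jacrev(f))(x)   # reverse-over-reverse

# Closed form: d^2/dx_i dx_j sum(sin(x)) is -sin(x_i) when i == j, else 0.
expected = torch.diag(-x.sin())

# atol is an illustrative tolerance for float32 round-off.
assert hess.shape == (5, 5)
assert torch.allclose(hess, expected, atol=1e-6)
assert torch.allclose(hess_fwd_rev, expected, atol=1e-6)
assert torch.allclose(hess_rev_rev, expected, atol=1e-6)
```

All three compositions should agree up to floating-point error; which one is faster depends on the function's input and output sizes, so it is worth benchmarking on the actual workload.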