Migrating from functorch to torch.func
======================================

torch.func, previously known as "functorch", provides
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the `pytorch/functorch <https://github.com/pytorch/functorch>`_ repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstream, we've decided to migrate from being a top level package
(``functorch``) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
``import functorch`` and ask that users migrate to the newest APIs, which we
will maintain going forward. ``import functorch`` will be kept around to maintain
backwards compatibility for a couple of releases.

function transforms
-------------------

The following APIs are drop-in replacements for the corresponding
`functorch APIs <https://pytorch.org/functorch/1.13/functorch.html>`_.
They are fully backwards compatible.

============================== =======================================
functorch API                  PyTorch API (as of PyTorch 2.0)
============================== =======================================
functorch.vmap                 :func:`torch.vmap` or :func:`torch.func.vmap`
functorch.grad                 :func:`torch.func.grad`
functorch.vjp                  :func:`torch.func.vjp`
functorch.jvp                  :func:`torch.func.jvp`
functorch.jacrev               :func:`torch.func.jacrev`
functorch.jacfwd               :func:`torch.func.jacfwd`
functorch.hessian              :func:`torch.func.hessian`
functorch.functionalize        :func:`torch.func.functionalize`
============================== =======================================

Furthermore, if you are using torch.autograd.functional APIs, please try out
the :mod:`torch.func` equivalents instead. :mod:`torch.func` function
transforms are more composable and more performant in many cases.

=========================================== =======================================
torch.autograd.functional API               torch.func API (as of PyTorch 2.0)
=========================================== =======================================
:func:`torch.autograd.functional.vjp`       :func:`torch.func.grad` or :func:`torch.func.vjp`
:func:`torch.autograd.functional.jvp`       :func:`torch.func.jvp`
:func:`torch.autograd.functional.jacobian`  :func:`torch.func.jacrev` or :func:`torch.func.jacfwd`
:func:`torch.autograd.functional.hessian`   :func:`torch.func.hessian`
=========================================== =======================================
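As a quick illustration of the second table (a minimal sketch, not taken from the
functorch docs; the function ``f`` and the tensor shapes are made up for this example),
here is what swapping :func:`torch.autograd.functional.jacobian` for
:func:`torch.func.jacrev` can look like::

    import torch

    def f(x):
        return x.sin().sum(dim=-1)

    x = torch.randn(4, 3)

    # torch.autograd.functional.jacobian computes the jacobian eagerly
    # from a callable and its inputs.
    jac_old = torch.autograd.functional.jacobian(f, x)

    # torch.func.jacrev instead returns a transformed function that computes
    # the jacobian when called; it composes with transforms like torch.vmap.
    jac_new = torch.func.jacrev(f)(x)

    assert torch.allclose(jac_old, jac_new)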
NN module utilities
-------------------

We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

functorch.make_functional
^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`torch.func.functional_call` is the replacement for
`functorch.make_functional <https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional>`_
and
`functorch.make_functional_with_buffers <https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers>`_.
However, it is not a drop-in replacement.

If you're in a hurry, you can use
`helper functions in this gist <https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf>`_
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
However, we recommend using :func:`torch.func.functional_call` directly because it is a more
explicit and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts the parameters and the inputs to the model as arguments.
:func:`torch.func.functional_call` allows one to call the forward pass of an existing
module using new parameters, buffers, and inputs.

Here's an example of how to compute gradients of the parameters of a model using functorch
vs :mod:`torch.func`::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)

    def compute_loss(params, inputs, targets):
        prediction = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = functorch.grad(compute_loss)(params, inputs, targets)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here's an example of how to compute jacobians of model parameters::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)
    jacobians = functorch.jacrev(fmodel)(params, inputs)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    from torch.func import jacrev, functional_call
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())
    # jacrev computes jacobians with respect to argnums=0 by default.
    # We set argnums=1 to compute jacobians with respect to params.
    jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that, to keep memory consumption down, it is important that you only carry
around a single copy of your parameters. ``model.named_parameters()`` does not copy
the parameters. If, during model training, you update the parameters of the model
in-place, then the ``nn.Module`` that is your model holds the single copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of the parameters: the one in the
dictionary and the one in the ``model``. In this case, you should change
``model`` to not hold memory by converting it to the meta device via
``model.to('meta')``.
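Here is a minimal sketch of that out-of-place pattern (the SGD-style update, learning
rate, and loss below are only for illustration and are not part of the original example)::

    import torch
    from torch.func import functional_call, grad

    model = torch.nn.Linear(3, 3)
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)

    # Copy the parameters out of the module first, then free the module's own
    # copy by moving it to the meta device, so only one real copy remains.
    params = {k: v.detach().clone() for k, v in model.named_parameters()}
    model.to('meta')

    def compute_loss(params, inputs, targets):
        prediction = functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = grad(compute_loss)(params, inputs, targets)

    # Out-of-place update: build a new dictionary instead of mutating the module.
    lr = 0.1
    params = {k: p - lr * grads[k] for k, p in params.items()}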
functorch.combine_state_for_ensemble
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please use :func:`torch.func.stack_module_state` instead of
`functorch.combine_state_for_ensemble <https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html>`_.
:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, that can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call`
for ensembling.

For example, here is how to ensemble over a very simple model::

    import torch
    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3
    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, 3)

    # ---------------
    # using functorch
    # ---------------
    import functorch
    fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
    output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import copy

    # Construct a version of the model with no memory by putting the Tensors on
    # the meta device.
    base_model = copy.deepcopy(models[0])
    base_model.to('meta')

    params, buffers = torch.func.stack_module_state(models)

    # It is possible to vmap directly over torch.func.functional_call,
    # but wrapping it in a function makes it clearer what is going on.
    def call_single_model(params, buffers, data):
        return torch.func.functional_call(base_model, (params, buffers), (data,))

    output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)


functorch.compile
-----------------

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
:func:`torch.compile` instead.
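For reference, a minimal sketch of what that looks like (the model and input below are
placeholders, not taken from any functorch.compile example)::

    import torch

    model = torch.nn.Linear(3, 3)

    # torch.compile wraps a module (or any function) and compiles it lazily
    # on the first call.
    compiled_model = torch.compile(model)

    output = compiled_model(torch.randn(64, 3))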