Migrating from functorch to torch.func
======================================

torch.func, previously known as "functorch", provides
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the `pytorch/functorch <https://github.com/pytorch/functorch>`_ repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstreaming process, we've decided to migrate from being a top-level
package (``functorch``) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
``import functorch`` and ask that users migrate to the newest APIs, which we
will maintain going forward. ``import functorch`` will be kept around to maintain
backwards compatibility for a couple of releases.

function transforms
-------------------

The following :mod:`torch.func` APIs are drop-in replacements for the corresponding
`functorch APIs <https://pytorch.org/functorch/1.13/functorch.html>`_.
They are fully backwards compatible.


==============================  =======================================
functorch API                   PyTorch API (as of PyTorch 2.0)
==============================  =======================================
functorch.vmap                  :func:`torch.vmap` or :func:`torch.func.vmap`
functorch.grad                  :func:`torch.func.grad`
functorch.vjp                   :func:`torch.func.vjp`
functorch.jvp                   :func:`torch.func.jvp`
functorch.jacrev                :func:`torch.func.jacrev`
functorch.jacfwd                :func:`torch.func.jacfwd`
functorch.hessian               :func:`torch.func.hessian`
functorch.functionalize         :func:`torch.func.functionalize`
==============================  =======================================

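For instance, switching an existing ``functorch.grad`` call over to :func:`torch.func.grad`
is just a matter of changing which namespace you call; here is a minimal sketch, with the
deprecated call kept as a comment::

    import torch

    def f(x):
        return (x ** 2).sum()

    x = torch.randn(3)

    # Previously (deprecated):
    #   import functorch
    #   grad_x = functorch.grad(f)(x)

    # As of PyTorch 2.0:
    grad_x = torch.func.grad(f)(x)
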
Furthermore, if you are using torch.autograd.functional APIs, please try out
the :mod:`torch.func` equivalents instead. :mod:`torch.func` function
transforms are more composable and more performant in many cases.

=========================================== =======================================
torch.autograd.functional API               torch.func API (as of PyTorch 2.0)
=========================================== =======================================
:func:`torch.autograd.functional.vjp`       :func:`torch.func.grad` or :func:`torch.func.vjp`
:func:`torch.autograd.functional.jvp`       :func:`torch.func.jvp`
:func:`torch.autograd.functional.jacobian`  :func:`torch.func.jacrev` or :func:`torch.func.jacfwd`
:func:`torch.autograd.functional.hessian`   :func:`torch.func.hessian`
=========================================== =======================================

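For example, here is a minimal sketch of replacing :func:`torch.autograd.functional.jacobian`
with :func:`torch.func.jacrev` for a simple function; both produce the same Jacobian here::

    import torch

    def f(x):
        return x.sin()

    x = torch.randn(3)

    # torch.autograd.functional API
    jac_autograd = torch.autograd.functional.jacobian(f, x)

    # torch.func API (as of PyTorch 2.0); composes with vmap, grad, etc.
    jac_func = torch.func.jacrev(f)(x)

    assert torch.allclose(jac_autograd, jac_func)
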
NN module utilities
-------------------

We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

functorch.make_functional
^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`torch.func.functional_call` is the replacement for
`functorch.make_functional <https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional>`_
and
`functorch.make_functional_with_buffers <https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers>`_.
However, it is not a drop-in replacement.

If you're in a hurry, you can use
`helper functions in this gist <https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf>`_
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using :func:`torch.func.functional_call` directly because it is a more explicit
and flexible API.

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts parameters and inputs to the model as arguments.
:func:`torch.func.functional_call` allows one to call the forward pass of an existing
module using new parameters, buffers, and inputs.

Here's an example of how to compute gradients of parameters of a model using functorch
vs :mod:`torch.func`::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)

    def compute_loss(params, inputs, targets):
        prediction = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = functorch.grad(compute_loss)(params, inputs, targets)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here's an example of how to compute jacobians of model parameters::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)
    jacobians = functorch.jacrev(fmodel)(params, inputs)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    from torch.func import jacrev, functional_call
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())
    # jacrev computes jacobians of argnums=0 by default.
    # We set it to 1 to compute jacobians of params.
    jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that, for memory consumption, it is important that you only carry
around a single copy of your parameters. ``model.named_parameters()`` does not copy
the parameters. If your model training updates the parameters of the model
in-place, then the ``nn.Module`` that is your model holds the single copy of the
parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of parameters: the one in the
dictionary and the one in the ``model``. In this case, you should change
``model`` to not hold memory by converting it to the meta device via
``model.to('meta')``.

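Here is a minimal sketch of that pattern; the out-of-place gradient-descent update at the
end is only illustrative::

    import torch

    model = torch.nn.Linear(3, 3)

    # Keep the only "real" copy of the parameters in a dictionary...
    params = {k: v.detach().clone() for k, v in model.named_parameters()}

    # ...and release the copy held by the module by moving it to the meta device.
    model.to('meta')

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    inputs, targets = torch.randn(64, 3), torch.randn(64, 3)
    grads = torch.func.grad(compute_loss)(params, inputs, targets)

    # Illustrative out-of-place update: builds a new dictionary rather than
    # mutating the parameters in-place.
    params = {k: v - 1e-3 * grads[k] for k, v in params.items()}
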
functorch.combine_state_for_ensemble
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please use :func:`torch.func.stack_module_state` instead of
`functorch.combine_state_for_ensemble <https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html>`_.
:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, which can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call`
for ensembling.

For example, here is how to ensemble over a very simple model::

    import torch
    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3
    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, 3)

    # ---------------
    # using functorch
    # ---------------
    import functorch
    fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
    output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import copy

    # Construct a version of the model with no memory by putting the Tensors on
    # the meta device.
    base_model = copy.deepcopy(models[0])
    base_model.to('meta')

    params, buffers = torch.func.stack_module_state(models)

    # It is possible to vmap directly over torch.func.functional_call,
    # but wrapping it in a function makes it clearer what is going on.
    def call_single_model(params, buffers, data):
        return torch.func.functional_call(base_model, (params, buffers), (data,))

    output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

functorch.compile
-----------------

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
:func:`torch.compile` instead.

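For example, a minimal sketch of compiling a plain function with :func:`torch.compile`::

    import torch

    def f(x):
        return torch.sin(x) + torch.cos(x)

    # Compile f; the first call triggers compilation, later calls reuse it.
    compiled_f = torch.compile(f)

    x = torch.randn(10)
    out = compiled_f(x)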