torch.func API Reference
========================

.. currentmodule:: torch.func

.. automodule:: torch.func

Function Transforms
-------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    linearize
    jacrev
    jacfwd
    hessian
    functionalize

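Transforms such as :func:`grad`, :func:`vmap` and :func:`hessian` take a function and return a new function, so they compose. The following is a minimal sketch that uses an arbitrary toy function ``f`` chosen for illustration. It applies :func:`grad` on its own, then :func:`vmap` on top of :func:`grad` to get one gradient per row of a batch, then :func:`hessian`:

.. code-block:: python

    import torch
    from torch.func import grad, hessian, vmap


    def f(x):
        # Toy scalar-valued function of one vector: d/dx sum(sin(x)) = cos(x).
        return torch.sin(x).sum()


    x = torch.randn(3)

    # grad(f) is a new function that returns the gradient of f at its input.
    g = grad(f)(x)
    assert torch.allclose(g, torch.cos(x))

    # vmap(grad(f)) maps the gradient computation over a leading batch dimension.
    xs = torch.randn(5, 3)
    batched_grads = vmap(grad(f))(xs)
    assert batched_grads.shape == (5, 3)
    assert torch.allclose(batched_grads, torch.cos(xs))

    # hessian(f) returns the matrix of second derivatives, diag(-sin(x)) here.
    H = hessian(f)(x)
    assert torch.allclose(H, torch.diag(-torch.sin(x)), atol=1e-6)
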
Utilities for working with torch.nn.Modules
-------------------------------------------

In general, you can transform over a function that calls a ``torch.nn.Module``.
For example, the following computes the Jacobian of a function that takes a
vector of three values and returns a vector of three values:

.. code-block:: python

    import torch
    from torch.func import jacrev

    model = torch.nn.Linear(3, 3)

    def f(x):
        return model(x)

    x = torch.randn(3)
    # The Jacobian of a function from 3 values to 3 values is a 3x3 matrix.
    jacobian = jacrev(f)(x)
    assert jacobian.shape == (3, 3)

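Other transforms work the same way. As a small sketch (the batch size of 5 is an arbitrary choice), :func:`vmap` can map ``f`` over a batch of inputs even though ``f`` is written for a single example:

.. code-block:: python

    import torch
    from torch.func import vmap

    model = torch.nn.Linear(3, 3)

    def f(x):
        # f is written for one unbatched example of shape (3,).
        return model(x)

    xs = torch.randn(5, 3)  # a batch of 5 examples
    out = vmap(f)(xs)
    assert out.shape == (5, 3)

    # nn.Linear also accepts batched input directly, so the results agree.
    assert torch.allclose(out, model(xs), atol=1e-6)
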
However, if you want to compute something like a Jacobian with respect to the
parameters of the model, you need a way to construct a function whose inputs are
those parameters. That's what :func:`functional_call` is for: it accepts an
``nn.Module``, a dictionary mapping parameter names to the tensors to use in
their place, and the inputs to the Module's forward pass. It returns the result
of running the Module's forward pass with the replacement parameters.

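To make the substitution concrete, here is a minimal sketch that passes all-zero replacements (an arbitrary choice for illustration) for the ``weight`` and ``bias`` of a ``torch.nn.Linear``. The output comes out as zeros, while the module's own parameters are left untouched:

.. code-block:: python

    import torch
    from torch.func import functional_call

    model = torch.nn.Linear(3, 3)
    x = torch.randn(3)
    weight_before = model.weight.detach().clone()

    # Replacement tensors keyed by the names from model.named_parameters().
    zero_params = {name: torch.zeros_like(p) for name, p in model.named_parameters()}

    # The forward pass uses the replacements, so y = 0 @ x + 0 = 0.
    out = functional_call(model, zero_params, (x,))
    assert torch.equal(out, torch.zeros(3))

    # The module's registered parameters are not modified by the call.
    assert torch.equal(model.weight, weight_before)
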
Here's how we would compute the Jacobian over the parameters:

.. code-block:: python

    import torch
    from torch.func import jacrev

    model = torch.nn.Linear(3, 3)

    def f(params, x):
        return torch.func.functional_call(model, params, x)

    x = torch.randn(3)
    # Differentiate f with respect to params, its first argument.
    jacobian = jacrev(f)(dict(model.named_parameters()), x)

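Because ``jacrev`` differentiates with respect to its first argument by default, and that argument here is a dictionary, the result is a dictionary with the same keys: one Jacobian per parameter, each shaped as the output's shape followed by that parameter's shape. The checks below sketch this for ``y = W x + b``, where the Jacobian with respect to ``b`` is the identity and the Jacobian with respect to ``W`` is ``x[k]`` wherever the output index equals the row index of ``W``:

.. code-block:: python

    import torch
    from torch.func import functional_call, jacrev

    model = torch.nn.Linear(3, 3)

    def f(params, x):
        return functional_call(model, params, x)

    x = torch.randn(3)
    jacobian = jacrev(f)(dict(model.named_parameters()), x)

    # The result mirrors the params dict: one Jacobian per parameter name.
    assert set(jacobian.keys()) == {"weight", "bias"}
    assert jacobian["weight"].shape == (3, 3, 3)  # output (3,) + weight (3, 3)
    assert jacobian["bias"].shape == (3, 3)  # output (3,) + bias (3,)

    # For y = W x + b: dy_i/db_j = delta_ij and dy_i/dW_jk = delta_ij * x_k.
    assert torch.allclose(jacobian["bias"], torch.eye(3))
    expected_w = torch.einsum("ij,k->ijk", torch.eye(3), x)
    assert torch.allclose(jacobian["weight"], expected_w)
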

.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_

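:func:`stack_module_state` and :func:`functional_call` can be combined with :func:`vmap` to evaluate several models of the same architecture at once (model ensembling). The sketch below assumes five ``torch.nn.Linear(3, 3)`` models and one shared batch of inputs; the names ``base_model`` and ``call_single_model`` are illustrative. It uses a copy of one model moved to the ``meta`` device purely as a structural template for ``functional_call``:

.. code-block:: python

    import copy

    import torch
    from torch.func import functional_call, stack_module_state, vmap

    # Five models with the same architecture but independently initialized weights.
    models = [torch.nn.Linear(3, 3) for _ in range(5)]

    # Stack each parameter (and buffer) across models along a new leading dim of size 5.
    params, buffers = stack_module_state(models)
    assert params["weight"].shape == (5, 3, 3)

    # Template module on the "meta" device: it holds no real data and only
    # supplies the structure and forward() used by functional_call.
    base_model = copy.deepcopy(models[0]).to("meta")

    def call_single_model(params, buffers, x):
        return functional_call(base_model, (params, buffers), (x,))

    x = torch.randn(4, 3)  # one batch of 4 inputs shared by every model
    # Map over the stacked params/buffers (dim 0) while reusing x unchanged (None).
    outputs = vmap(call_single_model, in_dims=(0, 0, None))(params, buffers, x)
    assert outputs.shape == (5, 4, 3)

    # Each slice matches running the corresponding original model directly.
    assert torch.allclose(outputs[2], models[2](x), atol=1e-6)
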
If you're looking for information on fixing Batch Norm modules, please follow the
guidance here:

.. toctree::
   :maxdepth: 1

   func.batch_norm

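As a brief sketch of one such fix, :func:`replace_all_batch_norm_modules_` updates a module in place so that its batch norm layers stop tracking running statistics, which lets :func:`vmap` run over them. The small ``Sequential`` network below is an arbitrary example:

.. code-block:: python

    import torch
    from torch.func import replace_all_batch_norm_modules_, vmap

    # An arbitrary small network containing a batch norm layer.
    net = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.BatchNorm1d(4))

    # Patch the network in place: batch norm layers no longer track running stats.
    replace_all_batch_norm_modules_(net)
    bn = net[1]
    assert bn.running_mean is None and bn.running_var is None
    assert not bn.track_running_stats

    # vmap over 2 independent batches of 8 examples each; in training mode each
    # batch is normalized with its own statistics.
    xs = torch.randn(2, 8, 3)
    out = vmap(net)(xs)
    assert out.shape == (2, 8, 4)
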