torch.func API Reference
========================

.. currentmodule:: torch.func

.. automodule:: torch.func

Function Transforms
-------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    vmap
    grad
    grad_and_value
    vjp
    jvp
    linearize
    jacrev
    jacfwd
    hessian
    functionalize

Utilities for working with torch.nn.Modules
-------------------------------------------

In general, you can transform over a function that calls a ``torch.nn.Module``.
For example, the following computes the Jacobian of a function that takes three
values and returns three values:

.. code-block:: python

    import torch
    from torch.func import jacrev

    model = torch.nn.Linear(3, 3)

    def f(x):
        return model(x)

    x = torch.randn(3)
    jacobian = jacrev(f)(x)
    assert jacobian.shape == (3, 3)

However, to compute something like a Jacobian with respect to the model's
parameters, you need a way to construct a function whose inputs are those
parameters. That's what :func:`functional_call` is for: it accepts an
``nn.Module``, a dictionary of ``parameters`` to use in place of the Module's
own, and the inputs to the Module's forward pass. It returns the result of
running the Module's forward pass with the replacement parameters.

Here's how we would compute the Jacobian with respect to the parameters:

.. code-block:: python

    import torch
    from torch.func import jacrev

    model = torch.nn.Linear(3, 3)

    def f(params, x):
        return torch.func.functional_call(model, params, x)

    x = torch.randn(3)
    jacobian = jacrev(f)(dict(model.named_parameters()), x)


.. autosummary::
    :toctree: generated
    :nosignatures:

    functional_call
    stack_module_state
    replace_all_batch_norm_modules_

If you're looking for information on fixing Batch Norm modules, please follow the
guidance here:

.. toctree::
    :maxdepth: 1

    func.batch_norm
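To make the shape of a parameter Jacobian computed with :func:`functional_call`
concrete, here is a minimal sketch for the same ``torch.nn.Linear(3, 3)`` model.
Because ``jacrev`` differentiates with respect to its first argument by default,
and that argument is a dictionary, the result is a dictionary with the same keys,
where each entry has shape (output shape) + (parameter shape). The seed, the
detached ``params`` copy, and the closed-form check for ``y = W @ x + b`` are
illustrative choices of this sketch, not part of the API.

.. code-block:: python

    import torch
    from torch.func import functional_call, jacrev

    torch.manual_seed(0)  # illustrative: makes the random inputs reproducible
    model = torch.nn.Linear(3, 3)

    def f(params, x):
        # Run the module's forward pass with `params` substituted for its own.
        return functional_call(model, params, x)

    # Detached copies (an illustrative choice) keep the result out of autograd.
    params = {name: p.detach() for name, p in model.named_parameters()}
    x = torch.randn(3)

    # Differentiates w.r.t. argument 0 (the dict), so `jac` has keys "weight", "bias".
    jac = jacrev(f)(params, x)

    # Output shape (3,) + parameter shape: weight -> (3, 3, 3), bias -> (3, 3).
    assert jac["weight"].shape == (3, 3, 3)
    assert jac["bias"].shape == (3, 3)

    # For y = W @ x + b: dy_i/db_j = delta_ij and dy_i/dW_jk = delta_ij * x_k.
    assert torch.allclose(jac["bias"], torch.eye(3))
    assert torch.allclose(jac["weight"], torch.einsum("ij,k->ijk", torch.eye(3), x))

The same pattern works for any module. Only the keys and shapes of the
returned dictionary change with the module's parameters.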