torch.func
==========

.. currentmodule:: torch.func

torch.func, previously known as "functorch", provides
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

.. note::
   This library is currently in `beta <https://pytorch.org/blog/pytorch-feature-classification-changes/#beta>`_.
   This means that the features generally work (unless otherwise documented)
   and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
   may change based on user feedback, and we don't have full coverage of PyTorch operations.

   If you have suggestions on the API or use cases you'd like to be covered, please
   open a GitHub issue or reach out. We'd love to hear about how you're using the library.

What are composable function transforms?
----------------------------------------

- A "function transform" is a higher-order function that accepts a numerical function
  and returns a new function that computes a different quantity.

- :mod:`torch.func` has auto-differentiation transforms (``grad(f)`` returns a function that
  computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)``
  returns a function that computes ``f`` over batches of inputs), and others.

- These function transforms can compose with each other arbitrarily. For example,
  composing ``vmap(grad(f))`` computes a quantity called per-sample-gradients that
  stock PyTorch cannot efficiently compute today.

Why composable function transforms?
-----------------------------------

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, and :func:`vjp` transforms allows us to express the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the `JAX framework <https://github.com/google/jax>`_.

Read More
---------

.. toctree::
   :maxdepth: 2

   func.whirlwind_tour
   func.api
   func.ux_limitations
   func.migrating
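
As a minimal sketch of the ``vmap(grad(f))`` composition described on this page, the snippet below computes per-sample gradients for a toy linear model. Swapping ``grad`` for :func:`hessian` gives per-sample (batched) Hessians. The model, the squared-error loss, and the tensor shapes are illustrative assumptions, not part of the library; only ``grad``, ``hessian``, and ``vmap`` come from :mod:`torch.func`.

```python
import torch
from torch.func import grad, hessian, vmap

# Hypothetical toy model (illustrative assumption): a linear predictor
# with a squared-error loss, written for a single sample.
def loss(weights, x, y):
    pred = x @ weights          # x: (3,), weights: (3,) -> scalar prediction
    return (pred - y) ** 2      # scalar loss for one sample

torch.manual_seed(0)
weights = torch.randn(3)
xs = torch.randn(8, 3)  # batch of 8 samples with 3 features each
ys = torch.randn(8)     # one target per sample

# grad(loss) differentiates with respect to the first argument (weights).
# vmap maps it over dim 0 of xs and ys while sharing weights (in_dims=None).
per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weights, xs, ys)

# The same composition yields one Hessian (w.r.t. weights) per sample.
per_sample_hessians = vmap(hessian(loss), in_dims=(None, 0, 0))(weights, xs, ys)

print(per_sample_grads.shape)     # torch.Size([8, 3])
print(per_sample_hessians.shape)  # torch.Size([8, 3, 3])
```

Because ``loss`` is written for a single sample, batching is handled entirely by ``vmap`` rather than by reshaping inside the model, which is the design idea behind composing transforms instead of building a separate subsystem for each use case.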