.. currentmodule:: functorch

UX Limitations
==============

functorch, like `JAX <https://github.com/google/jax>`_, has restrictions around
what can be transformed. In general, JAX's limitation is that transforms
only work with pure functions: that is, functions where the output is completely
determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions.
However, we do support certain in-place operations. On one hand, writing code
compatible with functorch transforms may involve changing how you write PyTorch
code; on the other hand, you may find that our transforms let you express things
that were previously difficult to express in PyTorch.

General limitations
-------------------

All functorch transforms share the limitation that a function should not
assign to global variables. Instead, all outputs of a function must be returned
from the function. This restriction comes from how functorch is implemented:
each transform wraps Tensor inputs in special functorch Tensor subclasses
that facilitate the transform.

So, instead of the following:

::

    import torch
    from functorch import grad

    # Don't do this
    intermediate = None

    def f(x):
        global intermediate
        intermediate = x.sin()
        z = intermediate.sin()
        return z

    x = torch.randn([])
    grad_x = grad(f)(x)

Please rewrite ``f`` to return ``intermediate``:

::

    def f(x):
        intermediate = x.sin()
        z = intermediate.sin()
        return z, intermediate

    grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside of a function being transformed by
:func:`vmap` or one of functorch's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it.
If it is unable to do so, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is implemented
and is the reason why we designed the functorch library. Please use the functorch
equivalents of the ``torch.autograd`` APIs instead:

- ``torch.autograd.grad``, ``Tensor.backward`` -> ``functorch.vjp`` or ``functorch.grad``
- ``torch.autograd.functional.jvp`` -> ``functorch.jvp``
- ``torch.autograd.functional.jacobian`` -> ``functorch.jacrev`` or ``functorch.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``functorch.hessian``
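For example, here is a minimal sketch of such a conversion (the ``loss`` function
and tensor shapes are just illustrative, not part of any API): the
``torch.autograd.grad`` call is replaced by :func:`grad`, which returns a
gradient-computing function that can then be composed with other transforms
such as :func:`vmap`:

::

    import torch
    from functorch import grad, vmap

    def loss(x):
        return (x ** 2).sum()

    x = torch.randn(3)

    # torch.autograd-style: works on its own, but cannot be transformed
    # by vmap or functorch's AD transforms.
    gx_autograd = torch.autograd.grad(loss(x.requires_grad_()), x)[0]

    # functorch-style: grad(loss) is a function from inputs to gradients.
    gx_functorch = grad(loss)(x)
    assert torch.allclose(gx_autograd, gx_functorch)

    # Because grad(loss) is just a function, it composes with vmap,
    # e.g. to compute per-sample gradients.
    per_sample_grads = vmap(grad(loss))(torch.randn(5, 3))

Note that unlike ``Tensor.backward``, :func:`grad` does not accumulate into
``.grad`` attributes; it returns the gradients directly.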
vmap limitations
----------------

.. note::
    :func:`vmap` is our most restrictive transform.
    The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not
    have these limitations. :func:`jacfwd` (and :func:`hessian`, which is
    implemented with :func:`jacfwd`) is a composition of :func:`vmap` and
    :func:`jvp`, so it also has these limitations.

``vmap(func)`` is a transform that returns a function that maps ``func`` over
some new dimension of each input Tensor. The mental model for vmap is that it is
like running a for-loop: for pure functions (i.e. in the absence of side
effects), ``vmap(f)(x)`` is equivalent to:

::

    torch.stack([f(x_i) for x_i in x.unbind(0)])

Mutation: Arbitrary mutation of Python data structures
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the presence of side effects, :func:`vmap` no longer acts like it is running
a for-loop. For example, the following function:

::

    def f(x, list):
        list.pop()
        print("hello!")
        return x.sum(0)

    x = torch.randn(3, 1)
    lst = [0, 1, 2, 3]

    result = vmap(f, in_dims=(0, None))(x, lst)

will print "hello!" only once and pop only one element from ``lst``:
:func:`vmap` executes ``f`` a single time, so all side effects happen only once.

This is a consequence of how vmap is implemented. functorch has a special,
internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs,
turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``.
BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized)
behavior for each PyTorch operator.

Mutation: in-place PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`vmap` will raise an error if it encounters an unsupported PyTorch
in-place operation; otherwise it will succeed. Unsupported operations
are those that would cause a Tensor with more elements to be written to a
Tensor with fewer elements. Here's an example of how this can occur:

::

    def f(x, y):
        x.add_(y)
        return x

    x = torch.randn(1)
    y = torch.randn(3)

    # Raises an error because `x` has fewer elements than `y`.
    vmap(f, in_dims=(None, 0))(x, y)

``x`` is a Tensor with one element and ``y`` is a Tensor with three elements.
``x + y`` has three elements (due to broadcasting), so attempting to write
those three elements back into ``x``, which has only one element, raises an error.

There is no problem if the Tensor being written to has the same number of
elements (or more):

::

    def f(x, y):
        x.add_(y)
        return x

    x = torch.randn(3)
    y = torch.randn(3)
    expected = x + y

    # Does not raise an error because x and y have the same number of elements.
    vmap(f, in_dims=(0, 0))(x, y)
    assert torch.allclose(x, expected)

Mutation: out= PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
It will gracefully raise an error if it encounters that in your code.

This is not a fundamental limitation; we could theoretically support this in the
future, but we have chosen not to for now.

Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
control flow is when the condition of an if-statement, while-loop, or
for-loop is a Tensor that is being ``vmap``'ed over. For example, the
following will raise an error message:

::

    def relu(x):
        if x > 0:
            return x
        return 0

    x = torch.randn(3)
    vmap(relu)(x)
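If the branch can be computed elementwise, a common workaround (shown here as a
sketch, not the only option) is to evaluate both branches and select between them
with ``torch.where``, which :func:`vmap` handles without issue:

::

    import torch
    from functorch import vmap

    def relu(x):
        # Both branches are computed; torch.where selects per element,
        # so there is no data-dependent Python control flow.
        return torch.where(x > 0, x, torch.zeros_like(x))

    x = torch.randn(3)
    vmap(relu)(x)

(For this particular function you could simply call ``torch.relu``; the pattern
matters for custom branches.)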
However, any control flow that is not dependent on the values in ``vmap``'ed
tensors will work:

::

    def custom_dot(x):
        if x.dim() == 1:
            return torch.dot(x, x)
        return (x * x).sum()

    x = torch.randn(3)
    vmap(custom_dot)(x)

JAX supports transforming over
`data-dependent control flow <https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators>`_
using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``).
We're investigating adding equivalents of those to functorch
(open an issue on `GitHub <https://github.com/pytorch/functorch>`_ to voice your support!).

Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We do not (and will not) support vmap over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error message:

::

    def f(x):
        return x.item()

    x = torch.randn(3)
    vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.

You may also see an error message about ``.item()`` even though your code does not
call it. In those cases, it is possible that PyTorch is calling ``.item()``
internally -- please file an issue on GitHub and we'll fix PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero`` and
``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
    vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor of shape ``(2, 1)``,
but ``torch.nonzero(xs[1])`` returns a Tensor of shape ``(1, 1)``.
We are unable to construct a single Tensor as an output;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).
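If the downstream computation only needs to know *which* elements are nonzero,
rather than a packed list of indices, one workaround (a sketch under that
assumption) is to compute a boolean mask instead; the mask has the same shape for
every example, so it vmaps cleanly:

::

    import torch
    from functorch import vmap

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])

    def nonzero_mask(x):
        # Unlike torch.nonzero, the output shape does not depend on the data.
        return x != 0

    vmap(nonzero_mask)(xs)
    # tensor([[False,  True,  True],
    #         [False, False,  True]])

If you truly need the packed indices for each example, you currently have to fall
back to a Python loop over the examples.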
Randomness
----------

The user's intention when calling a random operation can be unclear. Specifically, some users may want
the random behavior to be the same across batches while others may want it to differ across batches.
To address this, ``vmap`` takes a randomness flag.

The flag can only be passed to vmap and can take one of three values: "error", "different", or "same",
with "error" being the default. Under "error" mode, any call to a random function will produce an error
asking the user to use one of the other two flags based on their use case.

Under "different" randomness, elements in a batch produce different random values. For instance,

::

    def add_noise(x):
        y = torch.randn(())  # y will be different across the batch
        return x + y

    x = torch.ones(3)
    result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

Under "same" randomness, elements in a batch produce the same random value. For instance,

::

    def add_noise(x):
        y = torch.randn(())  # y will be the same across the batch
        return x + y

    x = torch.ones(3)
    result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

.. warning::
    Our system only determines the randomness behavior of PyTorch operators and cannot control the
    behavior of other libraries, like numpy. This is similar to the limitations of JAX's solution.

.. note::
    Multiple vmap calls using either type of supported randomness will not produce
    the same results. As with standard PyTorch, you can get reproducible randomness either by
    using ``torch.manual_seed()`` outside of vmap or by using generators.

.. note::
    Finally, our randomness differs from JAX's because we aren't using a stateless PRNG, in part because PyTorch
    doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the
    most common forms of randomness that we see. If your use case does not fit these forms of randomness, please
    file an issue.
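As a small sketch of the reproducibility note above (the seed value is arbitrary),
re-seeding the global PyTorch RNG before each call should make two otherwise-random
vmap calls agree:

::

    import torch
    from functorch import vmap

    def add_noise(x):
        return x + torch.randn(())

    x = torch.ones(3)

    torch.manual_seed(0)
    a = vmap(add_noise, randomness="different")(x)

    torch.manual_seed(0)
    b = vmap(add_noise, randomness="different")(x)

    # Re-seeding outside of vmap makes the two calls reproducible.
    assert torch.allclose(a, b)

Using an explicit, identically seeded ``torch.Generator``, as mentioned in the note
above, is the other option.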