.. currentmodule:: torch.func

.. _ux-limitations:

UX Limitations
==============

torch.func, like `JAX <https://github.com/google/jax>`_, has restrictions around
what can be transformed. In general, JAX's limitation is that transforms only
work with pure functions: that is, functions where the output is completely
determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions.
However, we do support certain in-place operations. On the one hand, writing code
compatible with function transforms may involve changing how you write PyTorch
code; on the other hand, you may find that our transforms let you express things
that were previously difficult to express in PyTorch.

General limitations
-------------------

All torch.func transforms share the limitation that a function should not
assign to global variables. Instead, all outputs of a function must be returned
from the function. This restriction comes from how torch.func is implemented:
each transform wraps Tensor inputs in special torch.func Tensor subclasses
that facilitate the transform.

So, instead of the following:

::

    import torch
    from torch.func import grad

    # Don't do this
    intermediate = None

    def f(x):
        global intermediate
        intermediate = x.sin()
        z = intermediate.sin()
        return z

    x = torch.randn([])
    grad_x = grad(f)(x)

Please rewrite ``f`` to return ``intermediate``:

::

    def f(x):
        intermediate = x.sin()
        z = intermediate.sin()
        return z, intermediate

    grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside of a function being transformed by
:func:`vmap` or one of torch.func's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it.
If it is unable to do so, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is implemented
and the reason why we designed the torch.func library. Please use the torch.func
equivalents of the ``torch.autograd`` APIs instead:

- ``torch.autograd.grad``, ``Tensor.backward`` -> ``torch.func.vjp`` or ``torch.func.grad``
- ``torch.autograd.functional.jvp`` -> ``torch.func.jvp``
- ``torch.autograd.functional.jacobian`` -> ``torch.func.jacrev`` or ``torch.func.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``torch.func.hessian``
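
For example, a gradient that would otherwise be computed with ``torch.autograd.grad``
can usually be expressed with :func:`grad` by writing the computation as a function of
its inputs. The snippet below is a minimal sketch; the function ``f`` is only for
illustration:

::

    import torch
    from torch.func import grad

    def f(x):
        # Toy scalar-valued computation, used only for illustration.
        return x.sin().sum()

    x = torch.randn(3)

    # Instead of:
    #   x.requires_grad_(True)
    #   grad_x, = torch.autograd.grad(f(x), x)
    #
    # use the torch.func equivalent, which can itself be composed with
    # other transforms such as vmap:
    grad_x = grad(f)(x)
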
vmap limitations
----------------

.. note::
    :func:`vmap` is our most restrictive transform.
    The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not
    have these limitations. :func:`jacfwd` (and :func:`hessian`, which is
    implemented with :func:`jacfwd`) is a composition of :func:`vmap` and
    :func:`jvp`, so it also has these limitations.

``vmap(func)`` is a transform that returns a function that maps ``func`` over
some new dimension of each input Tensor. The mental model for vmap is that it is
like running a for-loop: for pure functions (i.e. in the absence of side
effects), ``vmap(f)(x)`` is equivalent to:

::

    torch.stack([f(x_i) for x_i in x.unbind(0)])

Mutation: Arbitrary mutation of Python data structures
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the presence of side effects, :func:`vmap` no longer acts like it is running
a for-loop. For example, the following function:

::

    def f(x, list):
        list.pop()
        print("hello!")
        return x.sum(0)

    x = torch.randn(3, 1)
    lst = [0, 1, 2, 3]

    result = vmap(f, in_dims=(0, None))(x, lst)

will print "hello!" only once and pop only one element from ``lst``.

:func:`vmap` executes ``f`` a single time, so all side effects only happen once.

This is a consequence of how vmap is implemented. torch.func has a special,
internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs,
turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``.
BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized)
behavior for each PyTorch operator.

Mutation: in-place PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You might be here due to receiving an error about vmap-incompatible in-place
operations. :func:`vmap` will raise an error if it encounters an unsupported PyTorch
in-place operation, and it will succeed otherwise. Unsupported operations
are those that would cause a Tensor with more elements to be written to a
Tensor with fewer elements. Here's an example of how this can occur:

::

    def f(x, y):
        x.add_(y)
        return x

    x = torch.randn(1)
    y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

    # Raises an error because `x` has fewer elements than `y`.
    vmap(f, in_dims=(None, 0))(x, y)

``x`` is a Tensor with one element, ``y`` is a Tensor with three elements.
``x + y`` has three elements (due to broadcasting), but attempting to write
those three elements back into ``x``, which has only one element, raises an error.

There is no problem if the Tensor being written to is batched under
:func:`~torch.vmap` (i.e. it is being vmapped over).

::

    def f(x, y):
        x.add_(y)
        return x

    x = torch.randn(3, 1)
    y = torch.randn(3, 1)
    expected = x + y

    # Does not raise an error because x is being vmapped over.
    vmap(f, in_dims=(0, 0))(x, y)
    assert torch.allclose(x, expected)

One common fix for this is to replace calls to factory functions with
their "new_*" equivalents. For example:

- Replace :func:`torch.zeros` with :meth:`Tensor.new_zeros`
- Replace :func:`torch.empty` with :meth:`Tensor.new_empty`

To see why this helps, consider the following.

::

    def diag_embed(vec):
        assert vec.dim() == 1
        result = torch.zeros(vec.shape[0], vec.shape[0])
        result.diagonal().copy_(vec)
        return result

    vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

    # RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
    vmap(diag_embed)(vecs)

Inside of :func:`~torch.vmap`, ``result`` is a Tensor of shape [3, 3].
However, although ``vec`` looks like it has shape [3], ``vec`` actually has
underlying shape [2, 3]. It is not possible to copy ``vec`` into
``result.diagonal()``, which has shape [3], because ``vec`` has too many elements.

The fix is to construct ``result`` with :meth:`Tensor.new_zeros` instead:

::

    def diag_embed(vec):
        assert vec.dim() == 1
        result = vec.new_zeros(vec.shape[0], vec.shape[0])
        result.diagonal().copy_(vec)
        return result

    vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
    vmap(diag_embed)(vecs)

Replacing :func:`torch.zeros` with :meth:`Tensor.new_zeros` makes it so that
``result`` has an underlying Tensor of shape [2, 3, 3], so it is now possible
to copy ``vec``, which has underlying shape [2, 3], into ``result.diagonal()``.

Mutation: out= PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
It will error out gracefully if it encounters one in your code.

This is not a fundamental limitation; we could theoretically support this in the
future, but we have chosen not to for now.
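
For example, a sketch like the following (the helper functions and scratch buffer
are purely illustrative) would raise an error under :func:`vmap`; the fix is to call
the regular variant of the operation instead of its ``out=`` variant:

::

    import torch
    from torch.func import vmap

    def f(x):
        buf = torch.empty_like(x)
        torch.sin(x, out=buf)  # out= variants are not supported under vmap
        return buf

    x = torch.randn(3)
    # vmap(f)(x)  # raises an error because of the out= call

    def g(x):
        return torch.sin(x)  # the regular variant works under vmap

    vmap(g)(x)
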
Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
control flow is when the condition of an if-statement, while-loop, or
for-loop is a Tensor that is being ``vmap``'ed over. For example, the
following will raise an error message:

::

    def relu(x):
        if x > 0:
            return x
        return 0

    x = torch.randn(3)
    vmap(relu)(x)

However, any control flow that is not dependent on the values in ``vmap``'ed
tensors will work:

::

    def custom_dot(x):
        if x.dim() == 1:
            return torch.dot(x, x)
        return (x * x).sum()

    x = torch.randn(3)
    vmap(custom_dot)(x)

JAX supports transforming over
`data-dependent control flow <https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators>`_
using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``).
We're investigating adding equivalents of those to PyTorch.
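
In the meantime, one possible workaround, when the computation allows it, is to
rewrite the branch in terms of element-wise operations such as :func:`torch.where`.
As a sketch, the ``relu`` above could be written as:

::

    import torch
    from torch.func import vmap

    def relu(x):
        # Compute both branches and select element-wise, instead of
        # branching on the (batched) value of x.
        return torch.where(x > 0, x, torch.zeros_like(x))

    x = torch.randn(3)
    vmap(relu)(x)
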
Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We do not (and will not) support vmap over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error message:

::

    def f(x):
        return x.item()

    x = torch.randn(3)
    vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.

You may also encounter an error message about ``.item()`` even though you did not
use it. In those cases, it is possible that PyTorch is calling ``.item()``
internally -- please file an issue on GitHub and we'll fix PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero``
and ``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
    vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor of shape [2, 1], but
``torch.nonzero(xs[1])`` returns a Tensor of shape [1, 1].
We are unable to construct a single Tensor as an output;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).

Randomness
----------

The user's intention when calling a random operation can be unclear. Specifically,
some users may want the random behavior to be the same across the batch, while
others may want it to differ across the batch. To address this, ``vmap`` takes a
randomness flag.

The flag can only be passed to vmap and can take one of three values: "error",
"different", or "same", defaulting to "error". Under "error" mode, any call to a
random function will produce an error asking the user to use one of the other two
flags based on their use case.

Under "different" randomness, elements in a batch produce different random values.
For instance,

::

    def add_noise(x):
        y = torch.randn(())  # y will be different across the batch
        return x + y

    x = torch.ones(3)
    result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

Under "same" randomness, elements in a batch produce the same random values.
For instance,

::

    def add_noise(x):
        y = torch.randn(())  # y will be the same across the batch
        return x + y

    x = torch.ones(3)
    result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

.. warning::
    Our system only determines the randomness behavior of PyTorch operators and
    cannot control the behavior of other libraries, like NumPy. This is similar to
    the limitations of JAX's randomness solution.

.. note::
    Multiple vmap calls using either type of supported randomness will not produce
    the same results. Like with standard PyTorch, a user can get randomness
    reproducibility either by using ``torch.manual_seed()`` outside of vmap or by
    using generators.

.. note::
    Finally, our randomness differs from JAX's because we aren't using a stateless
    PRNG, in part because PyTorch doesn't have full support for a stateless PRNG.
    Instead, we've introduced a flag system to allow for the most common forms of
    randomness that we see. If your use case does not fit these forms of randomness,
    please file an issue.
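
As a concrete illustration of the reproducibility note above, re-seeding the global
generator before each call should make repeated ``vmap`` calls produce matching
results (a sketch):

::

    import torch
    from torch.func import vmap

    def add_noise(x):
        return x + torch.randn(())

    x = torch.ones(3)

    torch.manual_seed(0)
    first = vmap(add_noise, randomness="different")(x)

    torch.manual_seed(0)
    second = vmap(add_noise, randomness="different")(x)

    # Same seed outside of vmap -> same random values across the two calls.
    assert torch.allclose(first, second)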