.. currentmodule:: torch.func

.. _ux-limitations:

UX Limitations
==============

torch.func, like `JAX <https://github.com/google/jax>`_, has restrictions around
what can be transformed. In general, JAX's limitations are that transforms
only work with pure functions: that is, functions where the output is completely
determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions.
However, we do support certain in-place operations. On one hand, writing code
compatible with function transforms may involve changing how you write PyTorch
code; on the other hand, you may find that our transforms let you express things
that were previously difficult to express in PyTorch.

General limitations
-------------------

All torch.func transforms share a limitation in that a function should not
assign to global variables. Instead, all outputs of a function must be returned
from the function. This restriction comes from how torch.func is implemented:
each transform wraps Tensor inputs in special torch.func Tensor subclasses
that facilitate the transform.

So, instead of the following:

::

  import torch
  from torch.func import grad

  # Don't do this
  intermediate = None

  def f(x):
    global intermediate
    intermediate = x.sin()
    z = intermediate.sin()
    return z

  x = torch.randn([])
  grad_x = grad(f)(x)

Please rewrite ``f`` to return ``intermediate``:

::

  def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    return z, intermediate

  grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside of a function being transformed by
:func:`vmap` or one of torch.func's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it.
If it is unable to do so, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is implemented
and the reason why we designed the torch.func library. Please instead use the
torch.func equivalents of the ``torch.autograd`` APIs, as shown in the example below:

- ``torch.autograd.grad``, ``Tensor.backward`` -> ``torch.func.vjp`` or ``torch.func.grad``
- ``torch.autograd.functional.jvp`` -> ``torch.func.jvp``
- ``torch.autograd.functional.jacobian`` -> ``torch.func.jacrev`` or ``torch.func.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``torch.func.hessian``

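For example, here is a minimal sketch (the ``loss`` function and tensor shapes are
hypothetical) of using :func:`torch.func.grad` instead of ``torch.autograd.grad`` so
that the gradient computation composes with :func:`vmap`:

::

  import torch
  from torch.func import grad, vmap

  def loss(weights, sample):
    # A toy scalar loss standing in for your model's loss.
    return (weights * sample).sin().sum()

  weights = torch.randn(3)
  samples = torch.randn(8, 3)

  # Instead of calling torch.autograd.grad inside a vmapped function,
  # transform the pure function with torch.func.grad, then vmap over samples.
  per_sample_grads = vmap(grad(loss), in_dims=(None, 0))(weights, samples)
  assert per_sample_grads.shape == (8, 3)
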
vmap limitations
----------------

.. note::
  :func:`vmap` is our most restrictive transform.
  The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not
  have these limitations. :func:`jacfwd` (and :func:`hessian`, which is
  implemented with :func:`jacfwd`) is a composition of :func:`vmap` and
  :func:`jvp` so it also has these limitations.

``vmap(func)`` is a transform that returns a function that maps ``func`` over
some new dimension of each input Tensor. The mental model for vmap is that it is
like running a for-loop: for pure functions (i.e. in the absence of side
effects), ``vmap(f)(x)`` is equivalent to:

::

  torch.stack([f(x_i) for x_i in x.unbind(0)])

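To see this equivalence concretely, here is a small sketch (``f`` here is an
arbitrary function chosen for illustration) that checks that both paths produce
the same result:

::

  import torch
  from torch.func import vmap

  def f(x_i):
    return x_i.sin() + 1

  x = torch.randn(4, 3)
  batched = vmap(f)(x)
  looped = torch.stack([f(x_i) for x_i in x.unbind(0)])
  assert torch.allclose(batched, looped)
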
Mutation: Arbitrary mutation of Python data structures
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the presence of side effects, :func:`vmap` no longer acts like it is running
a for-loop. For example, the following function:

::

  def f(x, list):
    list.pop()
    print("hello!")
    return x.sum(0)

  x = torch.randn(3, 1)
  lst = [0, 1, 2, 3]

  result = vmap(f, in_dims=(0, None))(x, lst)

will print "hello!" once and pop only one element from ``lst``.

:func:`vmap` executes ``f`` a single time, so all side effects only happen once.

This is a consequence of how vmap is implemented. torch.func has a special,
internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs,
turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``.
BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized)
behavior for each PyTorch operator.

Mutation: in-place PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You might be here due to receiving an error about vmap-incompatible in-place
operations. :func:`vmap` will raise an error if it encounters an unsupported PyTorch
in-place operation and it will succeed otherwise. Unsupported operations
are those that would cause a Tensor with more elements to be written to a
Tensor with fewer elements. Here's an example of how this can occur:

::

  def f(x, y):
    x.add_(y)
    return x

  x = torch.randn(1)
  y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

  # Raises an error because `x` has fewer elements than `y`.
  vmap(f, in_dims=(None, 0))(x, y)

``x`` is a Tensor with one element and ``y`` is a Tensor with three elements.
``x + y`` has three elements (due to broadcasting), so attempting to write
those three elements back into ``x``, which has only a single element, raises
an error.

There is no problem if the Tensor being written to is batched under
:func:`~torch.vmap` (i.e. it is being vmapped over).

::

  def f(x, y):
    x.add_(y)
    return x

  x = torch.randn(3, 1)
  y = torch.randn(3, 1)
  expected = x + y

  # Does not raise an error because x is being vmapped over.
  vmap(f, in_dims=(0, 0))(x, y)
  assert torch.allclose(x, expected)

One common fix for this is to replace calls to factory functions with
their "new_*" equivalent. For example:

- Replace :func:`torch.zeros` with :meth:`Tensor.new_zeros`
- Replace :func:`torch.empty` with :meth:`Tensor.new_empty`

To see why this helps, consider the following.

::

  def diag_embed(vec):
    assert vec.dim() == 1
    result = torch.zeros(vec.shape[0], vec.shape[0])
    result.diagonal().copy_(vec)
    return result

  vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

  # RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
  vmap(diag_embed)(vecs)

Inside of :func:`~torch.vmap`, ``result`` is a Tensor of shape [3, 3].
However, although ``vec`` looks like it has shape [3], ``vec`` actually has
underlying shape [2, 3].
It is not possible to copy ``vec`` into ``result.diagonal()``, which has
shape [3], because ``vec`` has too many elements.

::

  def diag_embed(vec):
    assert vec.dim() == 1
    result = vec.new_zeros(vec.shape[0], vec.shape[0])
    result.diagonal().copy_(vec)
    return result

  vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
  vmap(diag_embed)(vecs)

Replacing :func:`torch.zeros` with :meth:`Tensor.new_zeros` makes it so that
``result`` has an underlying Tensor of shape [2, 3, 3], so it is now possible
to copy ``vec``, which has underlying shape [2, 3], into ``result.diagonal()``.


Mutation: out= PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
It will error out gracefully if it encounters that in your code.

This is not a fundamental limitation; we could theoretically support this in the
future but we have chosen not to for now.

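For illustration, here is a minimal sketch (the function and buffer names are
hypothetical) of the kind of code that triggers this error, along with a
straightforward rewrite that avoids ``out=``:

::

  def f(x, y, buf):
    # Unsupported under vmap: writes the result into a preallocated buffer.
    return torch.add(x, y, out=buf)

  def f_fixed(x, y):
    # Supported: let the operation allocate its own output.
    return torch.add(x, y)

  x = torch.randn(3, 5)
  y = torch.randn(3, 5)
  buf = torch.empty(5)

  # vmap(f, in_dims=(0, 0, None))(x, y, buf)  # errors: out= is not supported
  result = vmap(f_fixed, in_dims=(0, 0))(x, y)
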
Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
control flow is when the condition of an if-statement, while-loop, or
for-loop is a Tensor that is being ``vmap``'ed over. For example, the
following will raise an error message:

::

  def relu(x):
    if x > 0:
      return x
    return 0

  x = torch.randn(3)
  vmap(relu)(x)

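If the branch can be expressed with tensor operations instead of Python control
flow, a rewrite along the following lines (a sketch using :func:`torch.where`)
avoids the data-dependent if-statement:

::

  def relu(x):
    # Same result as the branching version, expressed with tensor ops.
    return torch.where(x > 0, x, torch.zeros_like(x))

  x = torch.randn(3)
  vmap(relu)(x)  # works
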
However, any control flow that is not dependent on the values in ``vmap``'ed
tensors will work:

::

  def custom_dot(x):
    if x.dim() == 1:
      return torch.dot(x, x)
    return (x * x).sum()

  x = torch.randn(3)
  vmap(custom_dot)(x)

JAX supports transforming over
`data-dependent control flow <https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators>`_
using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``).
We're investigating adding equivalents of those to PyTorch.

Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We do not (and will not) support vmap over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error message:

::

  def f(x):
    return x.item()

  x = torch.randn(3)
  vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.

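As a sketch of what such a rewrite can look like (assuming the Python number was
only needed for further tensor arithmetic), keep the value as a Tensor instead:

::

  def f(x):
    return x.item() * 2  # fails under vmap

  def f_fixed(x):
    return x * 2  # keeps the value as a Tensor; works under vmap

  x = torch.randn(3)
  vmap(f_fixed)(x)
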
You may also encounter an error message about using ``.item()`` but you might
not have used it. In those cases, it is possible that PyTorch internally is
calling ``.item()`` -- please file an issue on GitHub and we'll fix
PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero``
and ``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

  xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
  vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor of shape [2, 1] (it found two nonzero
elements), but ``torch.nonzero(xs[1])`` returns a Tensor of shape [1, 1].
We are unable to construct a single Tensor as an output;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).

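For contrast, here is a small sketch: an operation whose per-example output shape
always matches its input shape (such as an elementwise comparison) works fine
under :func:`vmap`:

::

  xs = torch.tensor([[0, 1, 2], [0, 0, 3]])

  # Each per-example output has the same shape as its input, so this works.
  mask = vmap(lambda x: x != 0)(xs)

  # Per-example outputs would have different shapes, so this errors.
  # vmap(torch.nonzero)(xs)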

Randomness
----------

The user's intention when calling a random operation can be unclear. Specifically, some users may want
the random behavior to be the same across the batch while others may want it to differ across the batch.
To address this, ``vmap`` takes a ``randomness`` flag.

The flag can only be passed to vmap and can take on 3 values: "error", "different", or "same", defaulting
to "error". Under "error" mode, any call to a random function will produce an error asking the user to use
one of the other two flags based on their use case.

Under "different" randomness, elements in a batch produce different random values. For instance,

::

  def add_noise(x):
    y = torch.randn(())  # y will be different across the batch
    return x + y

  x = torch.ones(3)
  result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

Under "same" randomness, elements in a batch produce the same random value. For instance,

::

  def add_noise(x):
    y = torch.randn(())  # y will be the same across the batch
    return x + y

  x = torch.ones(3)
  result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

.. warning::
    Our system only determines the randomness behavior of PyTorch operators and cannot control the
    behavior of other libraries, like NumPy. This is similar to the limitations of JAX's solutions.

.. note::
    Multiple vmap calls using either type of supported randomness will not produce
    the same results. As with standard PyTorch, a user can get randomness reproducibility
    either by using ``torch.manual_seed()`` outside of vmap or by using generators.

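For instance, here is a minimal sketch of getting reproducible results by seeding
outside of vmap (shown with ``randomness="different"``; the same idea applies to
``"same"``):

::

  def add_noise(x):
    return x + torch.randn(())

  x = torch.ones(3)

  torch.manual_seed(0)
  first = vmap(add_noise, randomness="different")(x)

  torch.manual_seed(0)
  second = vmap(add_noise, randomness="different")(x)

  assert torch.allclose(first, second)
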
.. note::
    Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch
    doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the
    most common forms of randomness that we see. If your use case does not fit these forms of randomness, please
    file an issue.
