.. currentmodule:: functorch

UX Limitations
==============

functorch, like `JAX <https://github.com/google/jax>`_, has restrictions around
what can be transformed. In general, JAX's limitation is that transforms
only work with pure functions: that is, functions where the output is completely
determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions.
However, we do support certain in-place operations. On one hand, writing code
compatible with functorch transforms may involve changing how you write PyTorch
code; on the other hand, you may find that our transforms let you express things
that were previously difficult to express in PyTorch.

General limitations
-------------------

All functorch transforms share a limitation in that a function should not
assign to global variables. Instead, all outputs of a function must be returned
from the function. This restriction comes from how functorch is implemented:
each transform wraps Tensor inputs in special functorch Tensor subclasses
that facilitate the transform.

So, instead of the following:

::

  import torch
  from functorch import grad

  # Don't do this
  intermediate = None

  def f(x):
    global intermediate
    intermediate = x.sin()
    z = intermediate.sin()
    return z

  x = torch.randn([])
  grad_x = grad(f)(x)

Please rewrite ``f`` to return ``intermediate``:

::

  def f(x):
    intermediate = x.sin()
    z = intermediate.sin()
    return z, intermediate

  grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside of a function being transformed by
:func:`vmap` or one of functorch's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it.
If it is unable to do so, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is implemented
and the reason why we designed the functorch library. Please instead use the functorch
equivalents of the ``torch.autograd`` APIs:

- ``torch.autograd.grad``, ``Tensor.backward`` -> ``functorch.vjp`` or ``functorch.grad``
- ``torch.autograd.functional.jvp`` -> ``functorch.jvp``
- ``torch.autograd.functional.jacobian`` -> ``functorch.jacrev`` or ``functorch.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``functorch.hessian``

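As a minimal sketch of such a rewrite (assuming a scalar input and the toy
function below), a ``torch.autograd.grad`` call can typically become a
``functorch.grad`` call:

::

  import torch
  from functorch import grad

  x = torch.randn([], requires_grad=True)

  # torch.autograd style: build a graph, then differentiate it
  y = x.sin()
  grad_x_autograd, = torch.autograd.grad(y, x)

  # functorch style: differentiate the function itself; this version also
  # composes with the other functorch transforms (e.g. vmap)
  grad_x_functorch = grad(torch.sin)(x)

  assert torch.allclose(grad_x_autograd, grad_x_functorch)
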
vmap limitations
----------------

.. note::
  :func:`vmap` is our most restrictive transform.
  The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not
  have these limitations. :func:`jacfwd` (and :func:`hessian`, which is
  implemented with :func:`jacfwd`) is a composition of :func:`vmap` and
  :func:`jvp` so it also has these limitations.

``vmap(func)`` is a transform that returns a function that maps ``func`` over
some new dimension of each input Tensor. The mental model for vmap is that it is
like running a for-loop: for pure functions (i.e. in the absence of side
effects), ``vmap(f)(x)`` is equivalent to:

::

  torch.stack([f(x_i) for x_i in x.unbind(0)])

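As a concrete illustration of this equivalence, here is a small sketch using an
arbitrary pure, per-example function:

::

  import torch
  from functorch import vmap

  def f(x):
    return x.sin() + 1  # any pure, per-example function

  x = torch.randn(4, 3)

  # vmap(f) maps f over dimension 0 of x, matching the explicit for-loop
  assert torch.allclose(vmap(f)(x),
                        torch.stack([f(x_i) for x_i in x.unbind(0)]))
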
Mutation: Arbitrary mutation of Python data structures
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the presence of side effects, :func:`vmap` no longer acts like it is running
a for-loop. For example, the following function:

::

  def f(x, list):
    list.pop()
    print("hello!")
    return x.sum(0)

  x = torch.randn(3, 1)
  lst = [0, 1, 2, 3]

  result = vmap(f, in_dims=(0, None))(x, lst)

will print "hello!" once and pop only one element from ``lst``.

:func:`vmap` executes ``f`` a single time, so all side effects only happen once.

This is a consequence of how vmap is implemented. functorch has a special,
internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs,
turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``.
BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized)
behavior for each PyTorch operator.

Mutation: in-place PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`vmap` will raise an error if it encounters an unsupported PyTorch
in-place operation and it will succeed otherwise. Unsupported operations
are those that would cause a Tensor with more elements to be written to a
Tensor with fewer elements. Here's an example of how this can occur:

::

  def f(x, y):
    x.add_(y)
    return x

  x = torch.randn(1)
  y = torch.randn(3)

  # Raises an error because `x` has fewer elements than `y`.
  vmap(f, in_dims=(None, 0))(x, y)

``x`` is a Tensor with one element and ``y`` is a Tensor with three elements.
``x + y`` has three elements (due to broadcasting), so attempting to write
those three elements back into ``x``, which has only one element, raises an
error.

There is no problem if the Tensor being written to has the same number of
elements (or more):

::

  def f(x, y):
    x.add_(y)
    return x

  x = torch.randn(3)
  y = torch.randn(3)
  expected = x + y

  # Does not raise an error because x and y have the same number of elements.
  vmap(f, in_dims=(0, 0))(x, y)
  assert torch.allclose(x, expected)

Mutation: out= PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
It will error out gracefully if it encounters that in your code.

This is not a fundamental limitation; we could theoretically support this in the
future but we have chosen not to for now.

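For example, the following sketch (assuming a simple elementwise function)
raises an error under :func:`vmap` because of the ``out=`` call:

::

  import torch
  from functorch import vmap

  def f(x):
    y = torch.empty(())
    torch.sin(x, out=y)  # out= variant of an operation
    return y

  x = torch.randn(3)
  vmap(f)(x)  # raises an error: out= is not supported under vmap
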
Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
control flow is when the condition of an if-statement, while-loop, or
for-loop is a Tensor that is being ``vmap``'ed over. For example, the
following will raise an error:

::

  def relu(x):
    if x > 0:
      return x
    return 0

  x = torch.randn(3)
  vmap(relu)(x)

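One common workaround (a sketch, not a functorch-specific API) is to express the
data-dependent branch with tensor operations such as ``torch.where``, which
:func:`vmap` can batch:

::

  def relu(x):
    # torch.where selects elementwise, so there is no Python-level
    # branch on a batched Tensor
    return torch.where(x > 0, x, torch.zeros_like(x))

  x = torch.randn(3)
  vmap(relu)(x)  # works
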
However, any control flow that is not dependent on the values in ``vmap``'ed
tensors will work:

::

  def custom_dot(x):
    if x.dim() == 1:
      return torch.dot(x, x)
    return (x * x).sum()

  x = torch.randn(3)
  vmap(custom_dot)(x)

JAX supports transforming over
`data-dependent control flow <https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators>`_
using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``).
We're investigating adding equivalents of those to functorch
(open an issue on `GitHub <https://github.com/pytorch/functorch>`_ to voice your support!).

Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We do not (and will not) support vmap over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error:

::

  def f(x):
    return x.item()

  x = torch.randn(3)
  vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.

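For example, a small sketch of such a rewrite (assuming the extracted value was
only used in further arithmetic) is to keep the computation in Tensor land
rather than converting to a Python number:

::

  def f(x):
    return x * 2          # instead of: return x.item() * 2

  x = torch.randn(3)
  vmap(f)(x)  # works, since no .item() call is made
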
You may also encounter an error message about ``.item()`` even though you did
not call it yourself. In those cases, it is possible that PyTorch internally is
calling ``.item()`` -- please file an issue on GitHub and we'll fix
PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero``
and ``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

  xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
  vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor with two rows (there are two nonzero
elements), but ``torch.nonzero(xs[1])`` returns a Tensor with one row.
We are unable to construct a single output Tensor;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).

Randomness
----------

The user's intention when calling a random operation can be unclear. Specifically, some users may want
the random behavior to be the same across the batch while others may want it to differ across the batch.
To address this, ``vmap`` takes a randomness flag.

The flag can only be passed to vmap and can take one of three values: ``"error"``, ``"different"``, or
``"same"``, defaulting to ``"error"``. Under ``"error"`` mode, any call to a random function will produce
an error asking the user to use one of the other two flags based on their use case.

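For instance, the following sketch raises an error under the default mode
because it calls a random function without choosing a randomness behavior:

::

  def add_noise(x):
    y = torch.randn(())
    return x + y

  x = torch.ones(3)
  vmap(add_noise)(x)  # raises an error: randomness defaults to "error"
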
Under ``"different"`` randomness, elements in a batch produce different random values. For instance,

::

  def add_noise(x):
    y = torch.randn(())  # y will be different across the batch
    return x + y

  x = torch.ones(3)
  result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

Under ``"same"`` randomness, elements in a batch produce the same random value. For instance,

::

  def add_noise(x):
    y = torch.randn(())  # y will be the same across the batch
    return x + y

  x = torch.ones(3)
  result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

.. warning::
    Our system only determines the randomness behavior of PyTorch operators and cannot control the
    behavior of other libraries, like numpy. JAX has a similar limitation with its randomness solution.

.. note::
    Multiple vmap calls using either type of supported randomness will not produce
    the same results. As with standard PyTorch, a user can get reproducibility either
    by using ``torch.manual_seed()`` outside of vmap or by using generators.

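For example, here is a minimal sketch of the seeding approach, reusing an
``add_noise`` helper and ``x`` as defined above:

::

  torch.manual_seed(0)
  first = vmap(add_noise, randomness="different")(x)

  torch.manual_seed(0)
  second = vmap(add_noise, randomness="different")(x)

  # Re-seeding outside of vmap makes the two calls reproducible
  assert torch.allclose(first, second)
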
.. note::
    Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch
    doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the
    most common forms of randomness that we see. If your use case does not fit these forms of randomness, please
    file an issue.