Patching Batch Norm
===================

What's happening?
-----------------
In training mode, batch norm updates its ``running_mean`` and ``running_var`` buffers in place, using
statistics computed from the input. Under vmap, those statistics are batched tensors, while the buffers
are regular (unbatched) tensors. functorch does not support an in-place update of a regular tensor with
a batched tensor (i.e. ``regular.add_(batched)`` is not allowed), so vmapping over a batch of inputs to a
single module raises an error.

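
The failure can be reproduced with a minimal sketch like the one below, assuming a PyTorch release
that still bundles functorch (the shapes are arbitrary). It vmaps a default ``BatchNorm2d`` in training
mode over four chunks of eight images, so each call sees an ordinary 4-D batch of shape
``(8, 3, 16, 16)``, and the call is expected to raise a ``RuntimeError``.

.. code-block:: python

    import torch
    import torch.nn as nn
    from functorch import vmap

    bn = nn.BatchNorm2d(3)  # track_running_stats=True by default
    bn.train()              # running stats are only updated in training mode

    # 4 independent chunks, each an (N=8, C=3, H=16, W=16) batch
    x = torch.randn(4, 8, 3, 16, 16)

    try:
        vmap(bn)(x)
        raised = False
    except RuntimeError:
        # running_mean/running_var are regular tensors, but the batch statistics are batched
        raised = True

    print(raised)
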
How to fix
----------
All of these options assume that you don't need running stats. If you're using a module, this means
you can't rely on batch norm's usual evaluation-mode behavior: without running stats, a BatchNorm
module normalizes with the statistics of the current batch even when it is in evaluation mode. If you
have a use case that involves running batch norm with vmap in evaluation mode, please file an issue.

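
To make that trade-off concrete, here is a small sketch (the shapes are arbitrary) showing that a
``BatchNorm2d`` built with ``track_running_stats=False`` keeps no running buffers and produces the same
output in training and evaluation mode, because both modes use the current batch's statistics.

.. code-block:: python

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    bn = nn.BatchNorm2d(3, track_running_stats=False)
    x = torch.randn(8, 3, 16, 16)

    bn.train()
    train_out = bn(x)

    bn.eval()
    eval_out = bn(x)  # still normalized with x's own mean and variance

    # No running buffers exist, so both modes compute the same thing
    print(bn.running_mean is None, torch.allclose(train_out, eval_out))  # True True
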
Option 1: Change the BatchNorm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you've built the module yourself, you can change it to not use running stats. In other words,
anywhere there's a BatchNorm module, set the ``track_running_stats`` flag to ``False``:

.. code-block:: python

    from torch.nn import BatchNorm2d

    BatchNorm2d(64, track_running_stats=False)

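
A minimal sketch of this fix in context, assuming a PyTorch release that bundles functorch:
``SmallNet`` is a hypothetical hand-written network whose batch norm is built with
``track_running_stats=False``, so it can be vmapped over several chunks of inputs without any in-place
buffer update.

.. code-block:: python

    import torch
    import torch.nn as nn
    from functorch import vmap

    class SmallNet(nn.Module):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
            # No running stats, so nothing is updated in place under vmap
            self.bn = nn.BatchNorm2d(16, track_running_stats=False)

        def forward(self, x):
            return torch.relu(self.bn(self.conv(x)))

    net = SmallNet()
    x = torch.randn(4, 8, 3, 16, 16)  # 4 chunks of (8, 3, 16, 16)
    out = vmap(net)(x)
    print(out.shape)  # torch.Size([4, 8, 16, 16, 16])
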
Option 2: torchvision parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some torchvision models, like ResNet and RegNet, take a ``norm_layer`` parameter. When it is left
unset, it defaults to ``BatchNorm2d``. Instead, you can pass a BatchNorm that doesn't use running
stats:

.. code-block:: python

    import torchvision
    from functools import partial
    from torch.nn import BatchNorm2d

    net = torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

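
If you write your own models, you can follow the same convention so that callers can swap in a norm
layer without running stats. The sketch below uses a hypothetical ``Block`` module (not part of
torchvision) that accepts ``norm_layer`` the way torchvision's ResNet does, and assumes a PyTorch
release that bundles functorch.

.. code-block:: python

    import torch
    import torch.nn as nn
    from functools import partial
    from functorch import vmap

    class Block(nn.Module):  # hypothetical module mirroring torchvision's norm_layer convention
        def __init__(self, channels, norm_layer=None):
            super().__init__()
            if norm_layer is None:
                norm_layer = nn.BatchNorm2d  # same default as torchvision's ResNet
            self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
            self.norm = norm_layer(channels)

        def forward(self, x):
            return torch.relu(self.norm(self.conv(x)))

    block = Block(8, norm_layer=partial(nn.BatchNorm2d, track_running_stats=False))
    x = torch.randn(2, 4, 8, 5, 5)  # 2 chunks of (4, 8, 5, 5)
    out = vmap(block)(x)
    print(out.shape)  # torch.Size([2, 4, 8, 5, 5])
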
Option 3: functorch's patching
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
functorch provides a helper for quick, in-place patching of a module. If you have a net that you want
to change, you can run ``replace_all_batch_norm_modules_`` to update the module in place so that it
does not use running stats:

.. code-block:: python

    from functorch.experimental import replace_all_batch_norm_modules_

    # net is the nn.Module to patch; its BatchNorm layers are modified in place
    replace_all_batch_norm_modules_(net)
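
A self-contained sketch of the patch in use, assuming a PyTorch release that bundles functorch (the
network below is arbitrary). After patching, the batch norm layer reports
``track_running_stats=False`` and has no running buffers, so vmapping over chunks of inputs succeeds.

.. code-block:: python

    import torch
    import torch.nn as nn
    from functorch import vmap
    from functorch.experimental import replace_all_batch_norm_modules_

    # An arbitrary network whose BatchNorm2d uses the default track_running_stats=True
    net = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )

    replace_all_batch_norm_modules_(net)  # patches net in place

    bn = net[1]
    print(bn.track_running_stats, bn.running_mean)  # False None

    x = torch.randn(4, 8, 3, 16, 16)  # 4 chunks of (8, 3, 16, 16)
    out = vmap(net)(x)
    print(out.shape)  # torch.Size([4, 8, 8, 16, 16])
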