Patching Batch Norm
===================

What's happening?
-----------------
Batch Norm requires in-place updates to ``running_mean`` and ``running_var`` of the same size as the input.
functorch does not support an in-place update to a regular tensor that takes in a batched tensor (i.e.
``regular.add_(batched)`` is not allowed). So when vmapping over a batch of inputs to a single module,
we end up with this error.

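To make the failure concrete, here is a minimal sketch that reproduces it (the exact error message may vary by version):

.. code-block:: python

    import torch
    from torch.func import vmap

    net = torch.nn.BatchNorm1d(3)   # tracks running stats by default
    x = torch.randn(5, 2, 3)        # 5 vmapped examples, each a (N=2, C=3) batch

    vmap(net)(x)  # errors: batch norm tries running_mean.add_(batched_update)
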
How to fix
----------
One of the best-supported ways is to switch BatchNorm for GroupNorm. Options 1 and 2 take this approach.

All of these options assume that you don't need running stats. If you're using a module, this means
that it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue.

Option 1: Change the BatchNorm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you want to switch to GroupNorm, anywhere that you have BatchNorm, replace it with:

.. code-block:: python

    GroupNorm(G, C)

Here ``C`` is the same ``C`` as in the original BatchNorm. ``G`` is the number of groups to
break ``C`` into. As such, ``C % G == 0`` and, as a fallback, you can set ``C == G``, meaning
each channel will be treated separately.

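Because GroupNorm computes its statistics from the current input alone and keeps no running
buffers, there is nothing for vmap to update in-place. A minimal sketch (the shapes here are
illustrative):

.. code-block:: python

    import torch
    from torch.nn import GroupNorm
    from torch.func import vmap

    C, G = 8, 4                      # C % G == 0
    norm = GroupNorm(G, C)           # stands in for BatchNorm2d(C)
    x = torch.randn(5, 2, C, 4, 4)   # 5 vmapped examples of (N=2, C, H, W)
    out = vmap(norm)(x)              # works: no running stats to update
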
If you must use BatchNorm and you've built the module yourself, you can change the module to
not use running stats. In other words, anywhere that there's a BatchNorm module, set the
``track_running_stats`` flag to ``False``:

.. code-block:: python

    BatchNorm2d(64, track_running_stats=False)


Option 2: torchvision parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some torchvision models, like resnet and regnet, can take in a ``norm_layer`` parameter, which
often defaults to ``BatchNorm2d``.

Instead, you can set it to be GroupNorm:

.. code-block:: python

    import torchvision
    from torch.nn import GroupNorm
    torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(num_groups=g, num_channels=c))

Here, once again, ``c % g == 0``, so as a fallback, set ``g = c``.

If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:

.. code-block:: python

    import torchvision
    from functools import partial
    from torch.nn import BatchNorm2d
    torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

Option 3: functorch's patching
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
functorch has added some functionality to allow for quick, in-place patching of a module so that
it does not use running stats. Changing the norm layer is more fragile, so we have not offered
that. If you have a net where you want the BatchNorm to not use running stats, you can run
``replace_all_batch_norm_modules_`` to update the module in-place:

.. code-block:: python

    from torch.func import replace_all_batch_norm_modules_
    replace_all_batch_norm_modules_(net)  # in-place: disables running stats on every BatchNorm in net

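For example, an end-to-end sketch (the model choice and input shapes here are illustrative):

.. code-block:: python

    import torch
    import torchvision
    from torch.func import replace_all_batch_norm_modules_, vmap

    net = torchvision.models.resnet18()
    replace_all_batch_norm_modules_(net)  # every BatchNorm now uses batch statistics only

    x = torch.randn(3, 2, 3, 64, 64)      # 3 vmapped examples of (N=2, C=3, H=64, W=64)
    out = vmap(net)(x)                    # works: no in-place running-stat updates
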
Option 4: eval mode
^^^^^^^^^^^^^^^^^^^
When run under eval mode, the ``running_mean`` and ``running_var`` will not be updated. Therefore,
vmap can support this mode:

.. code-block:: python

    from torch.func import vmap

    model.eval()    # batch norm now reads, but does not update, its running stats
    vmap(model)(x)
    model.train()