Patching Batch Norm
===================

What's happening?
-----------------
Batch Norm requires in-place updates to ``running_mean`` and ``running_var``, which are
computed from the batched input. functorch does not support an in-place update to a regular
tensor that takes in a batched tensor (i.e. ``regular.add_(batched)`` is not allowed). So when
vmapping over a batch of inputs to a single module, we end up with an error.

How to fix
----------
One of the best supported ways is to switch BatchNorm for GroupNorm. Options 1 and 2 support this.

All of these options assume that you don't need running stats. If you're using a module, this
means that it's assumed you won't use batch norm in evaluation mode. If you have a use case
that involves running batch norm with vmap in evaluation mode, please file an issue.

Option 1: Change the BatchNorm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you want to switch to GroupNorm, anywhere that you have BatchNorm, replace it with:

.. code-block:: python

    from torch.nn import GroupNorm

    GroupNorm(G, C)  # G groups over C channels; requires C % G == 0

Here ``C`` is the same ``C`` as in the original BatchNorm, and ``G`` is the number of groups to
break ``C`` into. As such, ``C % G == 0`` must hold; as a fallback, you can set ``G = C``,
meaning each channel will be treated separately.

If you must use BatchNorm and you've built the module yourself, you can change the module to
not use running stats. In other words, anywhere that there's a BatchNorm module, set the
``track_running_stats`` flag to ``False``:

.. code-block:: python

    from torch.nn import BatchNorm2d

    BatchNorm2d(64, track_running_stats=False)


Option 2: torchvision parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some torchvision models, like resnet and regnet, can take in a ``norm_layer`` parameter, which
defaults to ``BatchNorm2d`` when not given.

Instead you can set it to be GroupNorm:

.. code-block:: python

    import torchvision
    from torch.nn import GroupNorm

    # g must divide the channel count c that torchvision passes in
    torchvision.models.resnet18(norm_layer=lambda c: GroupNorm(g, c))

Here, once again, ``c % g == 0`` must hold, so as a fallback, set ``g = c``.

If you are attached to BatchNorm, be sure to use a version that doesn't use running stats:

.. code-block:: python

    import torchvision
    from functools import partial
    from torch.nn import BatchNorm2d

    torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

Option 3: functorch's patching
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
functorch has added some functionality to allow for quick, in-place patching of a module so
that it doesn't use running stats. Changing the norm layer is more fragile, so we have not
offered that. If you have a net where you want the BatchNorm to not use running stats, you can
run ``replace_all_batch_norm_modules_`` to update the module in-place:

.. code-block:: python

    from torch.func import replace_all_batch_norm_modules_

    replace_all_batch_norm_modules_(net)

Option 4: eval mode
^^^^^^^^^^^^^^^^^^^
When run under eval mode, ``running_mean`` and ``running_var`` will not be updated. Therefore,
vmap can support this mode:

.. code-block:: python

    model.eval()
    vmap(model)(x)
    model.train()
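
Putting it together
-------------------
As a quick end-to-end check, here is a minimal sketch of Option 3 applied to per-sample
computation. The choice of ``resnet18``, the input shapes, and the ``unsqueeze``/``squeeze``
trick are illustrative assumptions, not part of the API; ``replace_all_batch_norm_modules_``
and ``vmap`` are used exactly as described above. Note that ``BatchNorm2d`` still expects a
batch dimension, so the vmapped function adds (and removes) one per sample.

.. code-block:: python

    import torch
    import torchvision
    from torch.func import replace_all_batch_norm_modules_, vmap

    net = torchvision.models.resnet18()

    # Without this patch, vmapping over `net` in training mode would hit the
    # in-place update error described above.
    replace_all_batch_norm_modules_(net)

    x = torch.randn(4, 3, 224, 224)  # a batch of 4 images (illustrative shapes)

    # vmap passes one [3, 224, 224] sample per call; BatchNorm2d expects a batch
    # dimension, so add one inside the vmapped function and strip it afterwards.
    out = vmap(lambda sample: net(sample.unsqueeze(0)).squeeze(0))(x)
    print(out.shape)  # torch.Size([4, 1000])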