Patching Batch Norm
===================

What's happening?
-----------------
Batch Norm requires in-place updates to ``running_mean`` and ``running_var`` of the same size as the input.
Functorch does not support an in-place update to a regular tensor that takes in a batched tensor (i.e.
``regular.add_(batched)`` is not allowed). So when vmapping over a batch of inputs to a single module,
we end up with this error.

How to fix
----------
All of these options assume that you don't need running stats. If you're using a module, this means
it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue.

Option 1: Change the BatchNorm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you've built the module yourself, you can change it to not use running stats. In other
words, anywhere a BatchNorm module is created, set its ``track_running_stats`` flag to ``False``:

.. code-block:: python

    BatchNorm2d(64, track_running_stats=False)


Option 2: torchvision parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some torchvision models, such as resnet and regnet, accept a ``norm_layer`` parameter, which
defaults to ``BatchNorm2d``. Instead, you can set it to a BatchNorm that doesn't use running stats:

.. code-block:: python

    import torchvision
    from functools import partial
    from torch.nn import BatchNorm2d
    torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))

Option 3: functorch's patching
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
functorch has added some functionality to allow for quick, in-place patching of a module. If you
have a net that you want to change, you can run ``replace_all_batch_norm_modules_`` to update the
module in-place to not use running stats:

.. code-block:: python

    from functorch.experimental import replace_all_batch_norm_modules_
    replace_all_batch_norm_modules_(net)
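
Putting it together, here is a minimal sketch of the kind of code these fixes unblock: per-sample gradients through a network containing BatchNorm, computed with ``vmap``. It uses ``torch.func`` (where functorch's composable transforms now live in recent PyTorch releases); the toy network, shapes, and loss are illustrative, not part of the original doc. Because the BatchNorm layers are built with ``track_running_stats=False`` (Option 1), no in-place buffer update happens under ``vmap``:

.. code-block:: python

    import torch
    from torch import nn
    from torch.func import functional_call, grad, vmap

    # Toy network (illustrative): BatchNorm layers are created with
    # track_running_stats=False, so they hold no running_mean/running_var
    # buffers and perform no in-place updates under vmap.
    net = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1),
        nn.BatchNorm2d(8, track_running_stats=False),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 4 * 4, 10),
    )
    params = dict(net.named_parameters())

    def compute_loss(params, sample, target):
        # Each vmapped call sees a single sample; re-add a batch dim of 1
        # so BatchNorm2d still receives a 4D input.
        out = functional_call(net, params, (sample.unsqueeze(0),))
        return nn.functional.cross_entropy(out, target.unsqueeze(0))

    samples = torch.randn(5, 3, 4, 4)
    targets = torch.randint(0, 10, (5,))

    # vmap over the sample/target dims only; params are shared (in_dims=None).
    per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(
        params, samples, targets
    )
    print(per_sample_grads["0.weight"].shape)  # torch.Size([5, 8, 3, 3, 3])

With ``track_running_stats=True`` instead, the same ``vmap`` call would hit the in-place-update error described above; Option 3's ``replace_all_batch_norm_modules_(net)`` achieves the same effect on an already-built network.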