File: batch_norm.rst

Patching Batch Norm
===================

What's happening?
-----------------
Batch Norm requires in-place updates to its ``running_mean`` and ``running_var`` buffers, using
batch statistics computed from the input. Under vmap those statistics are batched tensors, and
functorch does not support in-place updates to a regular tensor with a batched one (i.e.
``regular.add_(batched)`` is not allowed). So when vmapping over a batch of inputs to a single
module, we end up with an error.
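
For example, here is a minimal sketch of one way to trigger the error (the module and shapes are
illustrative):

.. code-block:: python

    import torch
    from functorch import vmap

    net = torch.nn.BatchNorm1d(4)  # track_running_stats=True by default
    x = torch.randn(3, 8, 4)       # 3 vmapped copies of an (8, 4) batch

    # Each vmapped call computes batched statistics, so the in-place update
    # to the regular running_mean / running_var buffers fails.
    vmap(net)(x)  # raises a RuntimeError under vmap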

How to fix
----------
All of these options assume that you don't need running stats. If you're using a module, this means
it's assumed you won't use batch norm in evaluation mode. If you have a use case that involves
running batch norm with vmap in evaluation mode, please file an issue.

Option 1: Change the BatchNorm
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If you've built the module yourself, you can change the module to not use running stats. In other
words, anywhere there's a BatchNorm module, set its ``track_running_stats`` flag to ``False``.

.. code-block:: python

    from torch.nn import BatchNorm2d
    BatchNorm2d(64, track_running_stats=False)
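
With running stats disabled, there are no buffers to update, and the vmapped call from the sketch
above goes through:

.. code-block:: python

    import torch
    from functorch import vmap

    net = torch.nn.BatchNorm1d(4, track_running_stats=False)
    x = torch.randn(3, 8, 4)
    vmap(net)(x)  # works: each slice is normalized with its own batch statistics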


Option 2: torchvision parameter
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Some torchvision models, like resnet and regnet, can take in a ``norm_layer`` parameter, which
usually defaults to ``BatchNorm2d``. Instead, you can set it to a BatchNorm that doesn't use
running stats.

.. code-block:: python

    import torchvision
    from functools import partial
    from torch.nn import BatchNorm2d

    torchvision.models.resnet18(norm_layer=partial(BatchNorm2d, track_running_stats=False))
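
Here ``functools.partial`` pins the keyword argument, so every normalization layer the model
constructs is created with ``track_running_stats=False``.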

Option 3: functorch's patching
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
functorch has added some functionality to allow for quick, in-place patching of the module. If you
have a net that you want to change, you can run ``replace_all_batch_norm_modules_`` to update the
module in-place so that it no longer uses running stats.

.. code-block:: python

    from functorch.experimental import replace_all_batch_norm_modules_

    # net is any torch.nn.Module; every BatchNorm submodule is patched in-place
    replace_all_batch_norm_modules_(net)
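
For example, a sketch that patches a stock torchvision model and then vmaps over it (``resnet18``
here stands in for any network):

.. code-block:: python

    import torch
    import torchvision
    from functorch import vmap
    from functorch.experimental import replace_all_batch_norm_modules_

    net = torchvision.models.resnet18()
    replace_all_batch_norm_modules_(net)  # no BatchNorm in net tracks running stats anymore

    x = torch.randn(3, 2, 3, 224, 224)  # 3 vmapped copies of a (2, 3, 224, 224) batch
    out = vmap(net)(x)                   # out has shape (3, 2, 1000)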