File: test_deferred_batch_norm.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (194 lines) | stat: -rw-r--r-- 5,432 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Owner(s): ["oncall: distributed"]

# Copyright 2019 Kakao Brain
#
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from copy import deepcopy
from itertools import chain

import pytest
import torch
from torch import nn, optim

from torch.distributed.pipeline.sync.batchnorm import DeferredBatchNorm

# Default number of micro-batches: used as the default `chunks` of
# chunked_forward() and passed as `chunks=` to DeferredBatchNorm by most tests.
CHUNKS = 4


def tilt_dist(input):
    """Skew ``input`` in place so its statistics are far from uniform.

    Channels 0, 1 and 2 (dim 1) are scaled by 1, 10 and 100 respectively,
    tilting the per-channel variance. Then sample ``i`` along dim 0 is
    shifted by ``2 ** i``, tilting the per-sample mean.

    Args:
        input: tensor with the batch on dim 0 and at least three channels on
            dim 1 (only channels 0-2 are scaled).

    Returns:
        The very same tensor object, modified in place.
    """
    # transpose() returns a view, so scaling it writes through to ``input``.
    by_channel = input.transpose(0, 1)
    for channel, factor in enumerate((1, 10, 100)):
        by_channel[channel].mul_(factor)

    # Iterating over dim 0 yields views of each sample; shift each one.
    for position, sample in enumerate(input):
        sample.add_(2 ** position)

    return input


def chunked_forward(model, input, chunks=CHUNKS):
    """Run ``model`` micro-batch by micro-batch and reassemble the result.

    ``input`` is split with ``input.chunk(chunks)``. Each piece is passed
    through ``model`` in order, and the outputs are concatenated back
    together with ``torch.cat``.
    """
    return torch.cat([model(micro_batch) for micro_batch in input.chunk(chunks)])


@pytest.mark.parametrize("chunks", [1, 4])
@pytest.mark.parametrize("input_requires_grad", [True, False])
def test_transparency(chunks, input_requires_grad):
    """DeferredBatchNorm must match BatchNorm2d when both see the same chunks.

    Both modules are fed identical micro-batches through ``chunked_forward``.
    Their outputs must agree, as must the gradients of every affine parameter
    (weight and bias) and, when requested, the input gradients.
    """
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=chunks)

    input1 = torch.rand(16, 3, 224, 224)
    input1 = tilt_dist(input1)
    # Clone before enabling grad so the two inputs are independent tensors
    # whose gradients can be compared separately.
    input2 = input1.clone()
    input1.requires_grad = input_requires_grad
    input2.requires_grad = input_requires_grad

    output1 = chunked_forward(bn, input1, chunks=chunks)
    output2 = chunked_forward(dbn, input2, chunks=chunks)

    assert torch.allclose(output1, output2, atol=1e-4)

    output1.mean().backward()
    output2.mean().backward()

    # Transparency covers all affine parameters, not only the weight.
    assert torch.allclose(bn.weight.grad, dbn.weight.grad, atol=1e-4)
    assert torch.allclose(bn.bias.grad, dbn.bias.grad, atol=1e-4)

    if input_requires_grad:
        assert input1.grad is not None
        assert input2.grad is not None
        assert torch.allclose(input1.grad, input2.grad, atol=1e-4)


@pytest.mark.parametrize("momentum", [0.1, None])
def test_running_stats(momentum):
    """Running stats over micro-batches equal those over the whole mini-batch.

    A BatchNorm2d sees the full mini-batch once. Its deferred copy sees the
    same data as CHUNKS micro-batches. Afterwards, both must hold the same
    running mean and variance.
    """
    reference = nn.BatchNorm2d(3, momentum=momentum)
    deferred = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(reference), chunks=CHUNKS)

    batch = tilt_dist(torch.rand(16, 3, 224, 224))

    reference(batch)                  # one pass over the full mini-batch
    chunked_forward(deferred, batch)  # CHUNKS passes over micro-batches

    for stat_name in ("running_mean", "running_var"):
        expected = getattr(reference, stat_name)
        actual = getattr(deferred, stat_name)
        assert torch.allclose(expected, actual, atol=1e-4)


def test_convert_deferred_batch_norm():
    """Check when convert_deferred_batch_norm replaces a module and when not."""
    # With track_running_stats=False the module stays a plain BatchNorm2d.
    plain = DeferredBatchNorm.convert_deferred_batch_norm(
        nn.BatchNorm2d(3, track_running_stats=False), chunks=CHUNKS
    )
    assert type(plain) is nn.BatchNorm2d

    deferred = DeferredBatchNorm(3, chunks=CHUNKS)

    # Same chunk count: the very same module object comes back ...
    same = DeferredBatchNorm.convert_deferred_batch_norm(deferred, chunks=CHUNKS)
    assert deferred is same

    # ... a different chunk count produces a different module.
    other = DeferredBatchNorm.convert_deferred_batch_norm(deferred, chunks=CHUNKS + 1)
    assert deferred is not other


def test_eval():
    """After one training-mode pass, both modules agree in eval mode."""
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    data = tilt_dist(torch.rand(16, 3, 224, 224))

    # Populate running stats: full mini-batch vs. CHUNKS micro-batches.
    bn(data)
    chunked_forward(dbn, data)

    for module in (bn, dbn):
        module.eval()

    expected = bn(data)
    actual = dbn(data)
    assert torch.allclose(expected, actual, atol=1e-4)


def test_optimize():
    """Train BatchNorm2d and its deferred twin side by side for 5 steps.

    Each step trains both modules on the same batch: the full mini-batch for
    ``bn`` and CHUNKS micro-batches for ``dbn``. It then applies one shared
    SGD step and checks that their eval-mode outputs are still close.
    """
    bn = nn.BatchNorm2d(3)
    dbn = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(bn), chunks=CHUNKS)

    opt = optim.SGD(chain(bn.parameters(), dbn.parameters()), lr=1.0)

    for i in range(5):
        batch = torch.rand(16, 3, 224, 224)
        batch = tilt_dist(batch)

        # train
        # The previous iteration left both modules in eval mode; without
        # switching back, only the first iteration would really train.
        bn.train()
        dbn.train()
        # Clear gradients from the previous step so they don't accumulate.
        opt.zero_grad()

        y = bn(batch)
        a = y.sum()
        a.backward()

        y = chunked_forward(dbn, batch)
        b = y.sum()
        b.backward()

        opt.step()

        # eval
        bn.eval()
        dbn.eval()

        with torch.no_grad():
            # The absolute tolerance is loosened 10x per step.
            assert torch.allclose(bn(batch), dbn(batch), atol=1e-1 * (10 ** i))


def test_conv_bn():
    """Conv2d followed by BN: compare a regular BN with its deferred twin.

    The two pipelines produce different outputs, so their conv layers are
    trained differently. Still, after the first step their BNs hold matching
    running statistics. After a second step through the now-different conv
    layers, they no longer do.
    """
    reference = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
    deferred = DeferredBatchNorm.convert_deferred_batch_norm(deepcopy(reference), chunks=CHUNKS)

    batch = tilt_dist(torch.rand(16, 3, 224, 224))

    opt = optim.SGD(chain(reference.parameters(), deferred.parameters()), lr=0.1)

    def run_both():
        # Full mini-batch through the reference, micro-batches through the twin.
        return reference(batch), chunked_forward(deferred, batch)

    # --- Step 1 ---
    ref_out, def_out = run_both()

    # Per-mini-batch vs. per-micro-batch normalization gives different outputs.
    assert not torch.allclose(ref_out, def_out)

    ref_out.sum().backward()
    def_out.sum().backward()
    opt.step()
    opt.zero_grad()

    # The conv layers received different gradients, so their weights differ ...
    assert not torch.allclose(reference[0].weight, deferred[0].weight)

    # ... while both BNs tracked the same running statistics.
    ref_bn = reference[1]
    def_bn = deferred[1]
    assert torch.allclose(ref_bn.running_mean, def_bn.running_mean, atol=1e-4)
    assert torch.allclose(ref_bn.running_var, def_bn.running_var, atol=1e3)

    # --- Step 2 ---
    ref_out, def_out = run_both()
    ref_out.sum().backward()
    def_out.sum().backward()

    # Fed by different conv weights, the BN running stats now diverge.
    assert not torch.allclose(ref_bn.running_mean, def_bn.running_mean, atol=1e-4)
    assert not torch.allclose(ref_bn.running_var, def_bn.running_var, atol=1e3)


def test_input_requiring_grad():
    """``dbn.sum`` must stay outside autograd even for a grad-requiring input.

    After a chunked forward pass whose input requires grad, ``dbn.sum``
    (presumably the accumulated per-micro-batch sum used for the deferred
    statistics — confirm against DeferredBatchNorm) must neither require grad
    nor carry a ``grad_fn``.
    """
    dbn = DeferredBatchNorm(3, chunks=CHUNKS)

    sample = tilt_dist(torch.rand(16, 3, 224, 224))
    sample.requires_grad = True

    chunked_forward(dbn, sample)

    accumulated = dbn.sum
    assert not accumulated.requires_grad
    assert accumulated.grad_fn is None