File: test_fsdp_memory.py

# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.nn as nn
import torch.optim as optim
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    FSDPTest,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.utils.checkpoint import checkpoint


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def get_cur_mem(rank, result, prefix):
    """Collect memory allocated values in a result dict in MB"""
    result[prefix] = round(torch.cuda.memory_allocated() / 1024 / 1024)


class Model(nn.Module):
    def __init__(self, hidden_dim, with_fsdp=False, with_checkpoint=False):
        super().__init__()
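        # When with_fsdp is set, each BatchNorm below is wrapped in its own FSDP
        # unit, so its parameters form a flattened group separate from the conv
        # weights and are gathered and freed independently.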
        if with_fsdp:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                FSDP(nn.BatchNorm2d(64)),
                nn.ReLU(inplace=True),
            )
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(3, 64, kernel_size=3),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )
        if with_fsdp:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                FSDP(nn.BatchNorm2d(hidden_dim)),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )
        else:
            self.blocks = nn.Sequential(
                nn.Conv2d(64, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=5, padding=2),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(output_size=(1, 1)),
                nn.Flatten(),
            )

        self.head = nn.Linear(hidden_dim, 10)
        self.with_checkpoint = with_checkpoint

    def forward(self, x):
        if self.with_checkpoint:
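            # Activation checkpointing frees the intermediate activations of
            # `self.blocks` during the forward pass and recomputes them during
            # backward, trading compute for a lower peak memory.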
            return self.head(checkpoint(self.blocks, self.stem(x)))
        else:
            return self.head(self.blocks(self.stem(x)))


def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    torch.manual_seed(0)
    model = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if with_fsdp:
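        # Wrap each top-level submodule in its own FSDP unit so that parameters
        # are all-gathered and freed per unit rather than held materialized for
        # the whole model at once.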
        model.stem = FSDP(model.stem)
        model.blocks = FSDP(model.blocks)
        model.head = FSDP(model.head)

    return model


class TestFSDPMemory(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        gpu_id = self.rank
        world_size = self.world_size

        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
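        # The root FSDP wrap below shards any parameters not already claimed by
        # the inner FSDP units created in create_model.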
        model = FSDP(model)

        # We enable momentum so that after the first iteration, the optimizer state is added
        # to the total memory used.
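        # A momentum buffer has the same shape as its parameter, so the optimizer
        # state adds roughly one sharded-model-size worth of memory on each rank,
        # which is what the expected values in test_fsdp_memory assume.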
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

        results = {}  # results of memory stats
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

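            # Collapse the first sample's output to a scalar so an MSE loss
            # against a 0-d zero target can serve as a fake loss.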
            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # Use `set_to_none=True` rather than a plain optimizer.zero_grad() so
            # that the gradient tensors are freed, not just zeroed in place, and
            # their memory is reclaimed.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        def cmp(results, expected):
            ret = ""
            self.assertEqual(results.keys(), expected.keys())
            for k, v in results.items():
                exp = expected[k]
                if abs(exp - v) > 1:  # allow 1MB rounding differences
                    ret += f"{k}: got {v}, expected {exp}\n"
            return ret

        output = cmp(results, expected)
        self.assertEqual(output, "")

    @skip_if_lt_x_gpu(2)
    @parametrize("ckpt", ["no_ckpt", "ckpt"])
    def test_fsdp_memory(self, ckpt):
        # hidden_dim 128: model size ~4MB
        model_hidden_dim = 128

        model = create_model(
            with_fsdp=False, with_checkpoint=False, model_hidden_dim=model_hidden_dim
        ).cuda()
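        # At this point the unsharded model is the only CUDA allocation, so
        # memory_allocated() approximates the full model size.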
        model_size_mb = round(torch.cuda.memory_allocated() / 1024 / 1024)
        del model

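        # Once sharded, each of the world_size ranks holds roughly
        # model_size_mb / world_size MB of flattened parameters.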
        sharded_model_size_mb = int(model_size_mb / self.world_size)

        # We have observed that the 4th iteration can sometimes fail even when the
        # first 3 succeed (not in this test, but in much larger-scale tests), so we
        # run 4 iterations here just in case.
        iterations = 4

        expected = {}

        for iteration in range(iterations):
            if iteration == 0:
                # sharded model size + 1 MB temp memory
                expected[f"iter {iteration}: start"] = sharded_model_size_mb + 1
                # These activation-memory values are hard to derive analytically;
                # they were taken from the observed memory usage.
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = 51
                    expected[f"iter {iteration}: after loss"] = 51
                else:
                    expected[f"iter {iteration}: after fwd"] = 340
                    expected[f"iter {iteration}: after loss"] = 340
                # sharded model size + sharded grad size + 1 MB temp memory
                expected[f"iter {iteration}: after bwd"] = 2 * sharded_model_size_mb + 1
            else:
                # After the optimizer step in the first iteration, memory usage
                # increases by sharded_model_size_mb because of the added
                # optimizer state (momentum buffers).
                expected[f"iter {iteration}: start"] = 2 * sharded_model_size_mb + 1
                if ckpt == "ckpt":
                    expected[f"iter {iteration}: after fwd"] = (
                        51 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        51 + sharded_model_size_mb
                    )
                else:
                    expected[f"iter {iteration}: after fwd"] = (
                        340 + sharded_model_size_mb
                    )
                    expected[f"iter {iteration}: after loss"] = (
                        340 + sharded_model_size_mb
                    )
                expected[f"iter {iteration}: after bwd"] = 3 * sharded_model_size_mb + 1

            # sharded model size + sharded grad size + optimizer states + 1 MB temp memory
            expected[f"iter {iteration}: after step"] = 3 * sharded_model_size_mb + 1
            # grad memory is reclaimed once the grads are set to None
            # sharded model size + optimizer states + 1 MB temp memory
            expected[f"iter {iteration}: done"] = 2 * sharded_model_size_mb + 1

        # Derive the checkpoint flag from the parametrized test name.
        with_ckpt = ckpt == "ckpt"

        self._dist_train(
            with_ckpt,
            expected,
            model_hidden_dim,
            iterations,
        )


instantiate_parametrized_tests(TestFSDPMemory)


if __name__ == "__main__":
    run_tests()