File: test_fsdp_meta.py

# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import always_wrap_policy as always_wrap
from torch.distributed.fsdp.wrap import wrap, enable_wrap
from torch.testing._internal.common_fsdp import (
    FSDPTest,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
    parametrize,
    instantiate_parametrized_tests,
    sandcastle_skip_if,
)
from torch.testing._internal.common_distributed import (
    skip_if_lt_x_gpu,
)

_TORCHDISTX_AVAIL = True
try:
    from torchdistx import deferred_init
except ImportError:
    _TORCHDISTX_AVAIL = False


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def _reset_params_if_meta(is_meta, model):
    # For torchdistX init, we don't need to call reset_parameters(), since
    # materializing a deferred_init() module is equivalent to constructing it
    # normally.
    if is_meta:
        model.reset_parameters()

class MyLinear(nn.Linear):
    """
    Linear layer with deterministic reset_parameters for testing.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def reset_parameters(self, *args, **kwargs):
        with torch.no_grad():
            self.weight.fill_(1)

class MyModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.lin1 = MyLinear(2, 2, bias=False, device=device)
        self.lin2 = MyLinear(2, 2, bias=False, device=device)

    def forward(self, x):
        return self.lin2(self.lin1(x))

    def reset_parameters(self, *args, **kwargs):
        for m in [self.lin1, self.lin2]:
            if not isinstance(m, FSDP):
                m.reset_parameters()


class NestedModel(nn.Module):
    def __init__(self, device):
        super().__init__()
        self.lin1 = MyLinear(2, 2, bias=False, device=device)
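        # wrap() applies FSDP only when called inside an enable_wrap() context;
        # otherwise it returns the module unchanged, so NestedModel can be
        # constructed with or without nested wrapping.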
        self.lin1 = wrap(self.lin1)
        self.lin2 = MyLinear(2, 2, bias=False, device=device)
        self.l3 = MyModel(device=device)
        self.l3 = wrap(self.l3)

    def forward(self, x):
        return self.l3(self.lin2(self.lin1(x)))

    def reset_parameters(self):
        for m in [self.lin1, self.lin2, self.l3]:
            if not isinstance(m, FSDP):
                m.reset_parameters()

def _init_with_reset_params(module):
    """
    ``to_empty()`` + ``reset_parameters()`` init function example for modules
    initialized with device="meta".
    """
    is_meta = any(t.is_meta for t in module.parameters())
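    # to_empty() replaces meta tensors with real, uninitialized storage on the
    # target device; reset_parameters() below then writes actual values.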
    if is_meta:
        module.to_empty(device=torch.cuda.current_device())
    with torch.no_grad():
        module.reset_parameters()

def _init_with_torchdistX(module):
    """
    torchdistX-based deferred module initialization function example
    using ``materialize_module``.
    """
    assert _TORCHDISTX_AVAIL

    def check_fn(k):
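        # Only materialize modules that are not already FSDP-wrapped, so that
        # inner FSDP units are not materialized a second time.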
        return not isinstance(k, FSDP)

    deferred_init.materialize_module(module, check_fn=check_fn)

class TestFSDPWithMetaDevice(FSDPTest):
    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    def _compare_fsdp(self, fsdp1, fsdp2):
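        # summon_full_params() gathers the full, unsharded parameters on each
        # rank, so the comparison below runs over complete weights rather than
        # local shards.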
        with FSDP.summon_full_params(fsdp1):
            with FSDP.summon_full_params(fsdp2):
                for p1, p2 in zip(fsdp1.parameters(), fsdp2.parameters()):
                    self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

    def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None):
        # Create model on meta device and wrap with FSDP.
        model = meta_module_fn()
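        # True for the device="meta" path; torchdistX's deferred_init() keeps
        # parameters on the target device, so this stays False there (see
        # _reset_params_if_meta).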
        is_meta = next(model.parameters()).is_meta
        fsdp_meta = FSDP(
            model,
            auto_wrap_policy=always_wrap,
            param_init_fn=init_fn,
        )

        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

        # Test to make sure it is the same model parameters as regular FSDP
        # approach.
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

        # Test that meta init works if all submodules are contained in only a
        # single FSDP unit.
        model = meta_module_fn()
        fsdp_meta = FSDP(model, param_init_fn=init_fn)
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Run a forward + backward pass + optimizer step
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_reset_params(self):
        def meta_module_fn():
            return MyModel(device="meta")
        self._test_simple_model_with_meta_device(
            meta_module_fn, _init_with_reset_params
        )

    @skip_if_lt_x_gpu(2)
    def test_simple_model_with_meta_device_default_init(self):
        def meta_module_fn():
            return MyModel(device="meta")
        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @sandcastle_skip_if(
        not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX"
    )
    def test_simple_model_with_torchdistX_default_init(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    @sandcastle_skip_if(
        not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX"
    )
    def test_simple_model_with_torchdistX_init_fn(self):
        def meta_module_fn():
            return deferred_init.deferred_init(MyModel, device="cuda")

        self._test_simple_model_with_meta_device(meta_module_fn, init_fn=_init_with_torchdistX)

    def _test_nested_model_with_meta_device(self, auto_wrap, meta_module_fn, init_fn=None):
        if auto_wrap:
            module = meta_module_fn()
            is_meta = next(module.parameters()).is_meta
            fsdp_meta = FSDP(
                module,
                auto_wrap_policy=always_wrap,
                param_init_fn=init_fn,
            )
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            fsdp_regular = FSDP(
                module_regular,
                auto_wrap_policy=always_wrap,
            )
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
        else:
            with enable_wrap(
                wrapper_cls=FSDP, param_init_fn=init_fn,
            ):
                module = meta_module_fn()
                is_meta = next(module.parameters()).is_meta
                # Non-FSDP modules will still be initialized because they
                # bubble up to be part of a larger FSDP unit.
                fsdp_meta = wrap(module)
                meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

            # Init and reset parameters before wrapping so that reset_params
            # matches up with meta device's initialization.
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
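            # Manually wrap the same submodules as the meta model so that both
            # FSDP module trees have matching structure for the comparison
            # below.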
            with enable_wrap(wrapper_cls=FSDP):
                module_regular.lin1 = wrap(module_regular.lin1)
                module_regular.l3 = wrap(module_regular.l3)
                fsdp_regular = wrap(module_regular)
                regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Compare it before training
        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device="cuda")
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_reset_params(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn, init_fn=_init_with_reset_params
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_meta_device_default_init(self, auto_wrap):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn,
        )

    @skip_if_lt_x_gpu(2)
    @sandcastle_skip_if(
        not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX"
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_default_init(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn
        )

    @skip_if_lt_x_gpu(2)
    @sandcastle_skip_if(
        not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX"
    )
    @parametrize("auto_wrap", [True, False])
    def test_nested_model_with_torchdistX_init_fn(self, auto_wrap):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_nested_model_with_meta_device(
            auto_wrap=auto_wrap, meta_module_fn=meta_module_fn, init_fn=_init_with_torchdistX,
        )

    def _test_bad_arg(self, meta_module_fn):
        mod = meta_module_fn()
        with self.assertRaisesRegex(ValueError, "to be callable"):
            FSDP(mod, param_init_fn=42)

    @skip_if_lt_x_gpu(2)
    @sandcastle_skip_if(
        not _TORCHDISTX_AVAIL, "Test requires torchdistX: https://github.com/pytorch/torchdistX"
    )
    def test_bad_arg_torchdistx(self):
        def meta_module_fn():
            return deferred_init.deferred_init(NestedModel, device="cuda")

        self._test_bad_arg(meta_module_fn)

    @skip_if_lt_x_gpu(2)
    def test_bad_arg_meta(self):
        def meta_module_fn():
            return NestedModel(device="meta")

        self._test_bad_arg(meta_module_fn)


instantiate_parametrized_tests(TestFSDPWithMetaDevice)

if __name__ == "__main__":
    run_tests()