File: test_fsdp_misc.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (488 lines) | stat: -rw-r--r-- 18,388 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
# Owner(s): ["oncall: distributed"]

from copy import deepcopy
import functools
import sys
from collections import namedtuple
from contextlib import suppress

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FlatParameter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy, CPUOffload
from torch.distributed.fsdp.wrap import (
    always_wrap_policy,
    transformer_auto_wrap_policy,
)
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
    TransformerWithSharedParams,
    _assert_module_states,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

# Bail out of the whole module before any test is defined when this
# environment cannot run it: exit status 0 with a reason printed to stderr.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
elif TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestFSDPMisc(FSDPTest):
    """Miscellaneous FSDP tests covering:

    - namedtuple outputs and outputs or submodules that do not contribute
      to the loss;
    - ``device_id`` handling, CPU initialization and CPU offload;
    - ``sync_module_states``.
    """

    @property
    def world_size(self):
        # Every test in this class runs with two ranks.
        return 2

    @property
    def process_group(self):
        return dist.distributed_c10d._get_default_group()

    @skip_if_lt_x_gpu(2)
    def test_fsdp_namedtuple(self):
        """Tests that a namedtuple passes through an FSDP forward unchanged
        and that backward hooks are registered on each output tensor."""
        # Ensure namedtuple support, preventing issues such as
        # https://github.com/pytorch/pytorch/issues/83053
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(100, 100)

            def forward(self, x):
                return x

        m = MyModule().cuda()
        m = FSDP(m)
        t = torch.ones(1, device="cuda", requires_grad=True)

        MyOutputType = namedtuple(
            "MyOutputType",
            ["a", "b", "c", "d"],
            defaults=(t, t, t, t)
        )

        inp = MyOutputType()
        out = m(inp)
        # Ensure hooks are registered
        for x in out:
            self.assertNotEqual([], list(x._backward_hooks.values()))

        # TODO: we should check backward() and param is resharded
        # as well, but this is blocked by
        # https://github.com/pytorch/pytorch/issues/83107 and
        # https://github.com/pytorch/pytorch/issues/83129

    @skip_if_lt_x_gpu(2)
    def test_fsdp_not_all_outputs_used_in_loss(self):
        """Tests that FSDP training matches local (non-FSDP) training, and
        that parameters are resharded after backward, when one forward
        output stops contributing to the loss.

        Runs for the ``FULL_SHARD``, ``SHARD_GRAD_OP`` and ``NO_SHARD``
        sharding strategies.
        """

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = nn.Linear(4, 4)
                self.lin2 = nn.Linear(4, 4)

            def forward(self, x):
                a = self.lin1(x)
                b = self.lin2(x)
                return (a, b)

        def _check_resharded(fsdp_module):
            # After backward, each flat parameter must point back at its local
            # shard; for sharded strategies the padded unsharded buffer must
            # also have a zero-size storage.
            for handle in fsdp_module._handles:
                param = handle.flat_param
                if handle.uses_sharded_strategy:
                    full_param = param._full_param_padded
                    self.assertEqual(full_param.storage().size(), 0)

                self.assertEqual(
                    param.data_ptr(),
                    param._local_shard.data_ptr()
                )

        def _check_equal(local, fsdp):
            # Compare the unsharded FSDP parameters with the local model's.
            # `assert_allclose` is deprecated; the explicit rtol/atol are
            # intended to keep the float32 tolerance it used here.
            with FSDP.summon_full_params(fsdp):
                for p1, p2 in zip(fsdp.parameters(), local.parameters()):
                    torch.testing.assert_close(p1, p2, rtol=1e-4, atol=1e-5)

        for sharding_strategy in [
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.NO_SHARD,
        ]:
            with self.subTest(sharding_strategy=sharding_strategy):
                fsdp_ctor = functools.partial(FSDP, sharding_strategy=sharding_strategy)
                m = MyModule().cuda()
                m_local = deepcopy(m)
                prev_params = [p.clone() for p in m_local.parameters()]

                m.lin1 = fsdp_ctor(m.lin1)
                m = fsdp_ctor(m)
                _check_equal(m_local, m)

                opt = torch.optim.SGD(m.parameters(), lr=1e-3)
                opt_local = torch.optim.SGD(m_local.parameters(), lr=1e-3)

                for i in range(6):
                    t = torch.ones(4, device="cuda")
                    a, b = m(t)
                    local_a, local_b = m_local(t)
                    if i < 2:
                        # Use both outputs in the loss computation. Later,
                        # b goes unused and we check that the grads still
                        # match local training.
                        loss = (a @ b).sum()
                        loss_local = (local_a @ local_b).sum()
                    else:
                        loss = a.sum()
                        loss_local = local_a.sum()

                    loss.backward()
                    loss_local.backward()
                    _check_resharded(m)
                    opt.step()
                    opt_local.step()
                    _check_equal(m_local, m)
                    # Ensure at least some change from the previous params;
                    # otherwise the check above would be vacuously true.
                    self.assertTrue(
                        any(
                            not torch.equal(p1, p2)
                            for p1, p2 in zip(prev_params, m_local.parameters())
                        )
                    )
                    prev_params = [p.clone() for p in m_local.parameters()]
                    opt.zero_grad()
                    opt_local.zero_grad()

                dist.barrier()

    @skip_if_lt_x_gpu(2)
    @parametrize("use_second_layer", [True, False])
    @parametrize("sharding_strategy", [ShardingStrategy.NO_SHARD, None])
    def test_fsdp_module_no_compute_grad(self, use_second_layer, sharding_strategy):
        """Tests that a wrapped submodule whose output does not feed the loss
        gets no gradient on its flat parameter.

        Runs for ``sharding_strategy`` ``NO_SHARD`` and ``None``.
        """
        # When use_second_layer=True, b is involved in forward computation but does
        # not receive grad in backward. Otherwise, b is not involved in forward
        # computation.
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(10, 10)
                self.b = nn.Linear(10, 10)

            def forward(self, x, y):
                out1 = self.a(x)
                if use_second_layer:
                    out2 = self.b(y)
                    return out1, out2
                else:
                    return out1

        fsdp = FSDP(
            MyModel().cuda(),
            sharding_strategy=sharding_strategy,
            auto_wrap_policy=always_wrap_policy
        )
        x = torch.randn(10, 10, device='cuda')
        y = torch.randn(10, 10, device='cuda')
        for _ in range(4):
            if use_second_layer:
                # The second output is deliberately left out of the loss.
                a, _ = fsdp(x, y)
            else:
                a = fsdp(x, y)
            loss = a.sum()
            loss.backward()

            # self.a receives grad, self.b does not
            a_grad = fsdp.module.a._fsdp_wrapped_module.flat_param.grad
            b_grad = fsdp.module.b._fsdp_wrapped_module.flat_param.grad
            self.assertIsNotNone(a_grad)
            self.assertIsNone(b_grad)

    @skip_if_lt_x_gpu(2)
    def test_device_id_auto_wrap(self):
        """Tests that ``auto_wrap_policy`` propagates ``device_id`` to all
        nested FSDP instances."""
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
        )
        fsdp_kwargs = {
            "auto_wrap_policy": auto_wrap_policy,
            "device_id": torch.cuda.current_device(),
        }
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            self.assertEqual(
                fsdp_module.compute_device,
                torch.device("cuda", torch.cuda.current_device()),
            )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_device_id_cpu_offload(self):
        """
        Ensures that even if device_id is specified but we have
        CPU offload, module is on CPU after init.
        """
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(10, 10)
                self.b = nn.Linear(10, 10)

            def forward(self, x):
                return self.b(self.a(x))

        model = MyModel()

        fsdp = FSDP(
            model,
            auto_wrap_policy=always_wrap_policy,
            cpu_offload=CPUOffload(offload_params=True),
            device_id=torch.cuda.current_device()
        )

        cpu_device = torch.device("cpu")

        for fsdp_unit in FSDP.fsdp_modules(fsdp):
            # An FSDP unit may not directly manage any parameters, in which
            # case the inner loop is a no-op. Otherwise every managed
            # parameter must be on CPU.
            for fsdp_param in fsdp_unit.params:
                self.assertEqual(fsdp_param.device, cpu_device)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_index", [True, False])
    def test_fsdp_device_id(self, use_index):
        """
        Tests the FSDP ``device_id`` argument:
          - Wrapping a CPU module should move the module to the GPU matching
          ``device_id``
          - Wrapping a GPU module already on the GPU matching ``device_id``
          should not raise an error
          - Wrapping a GPU module already on GPU and passing a GPU device
          without specifying a device ID (i.e. ``torch.device("cuda")``) warns
        """
        dev_id = (
            torch.cuda.current_device() if use_index
            else torch.device("cuda", torch.cuda.current_device())
        )

        def _check_device_matches(module, device_id):
            """Checks that the ``FlatParameter``s in ``module`` have device
            matching ``device_id``.

            ``device_id`` may be an integer CUDA index or a ``torch.device``.
            """
            devices = {
                p.device for p in module.parameters()
                if isinstance(p, FlatParameter)
            }
            self.assertGreater(
                len(devices), 0, "Expected at least one FlatParameter"
            )
            self.assertEqual(1, len(devices))
            found_device = devices.pop()
            # Normalize an integer index to the corresponding CUDA device.
            if isinstance(device_id, torch.device):
                device = device_id
            else:
                device = torch.device("cuda", device_id)
            self.assertEqual(found_device, device)

        # Check that FSDP parameters are moved to `device_id` for a CPU module
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that specifying `device_id` for a GPU module already on that
        # device does not raise an error
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs={"device_id": dev_id},
        )
        _check_device_matches(nested_wrapped_module, dev_id)
        # Check that passing in `torch.device("cuda")` for a GPU module warns
        regex = "does not have an explicit index"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_BEFORE,
                fsdp_kwargs={"device_id": torch.device("cuda")}
            )
        _check_device_matches(
            nested_wrapped_module,
            torch.device("cuda", torch.cuda.current_device())
        )

    @skip_if_lt_x_gpu(2)
    def test_module_device_mismatches_device_id(self):
        """Tests that specifying a ``device_id`` argument to FSDP for a GPU
        module that does not match the GPU device ID raises an error."""
        context = (
            self.assertRaisesRegex(
                ValueError,
                f"cuda:{self.rank} vs cuda:0"
            ) if self.rank != 0 else suppress()
        )
        with context:
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                # Move wrapped modules to CUDA before wrapping with FSDP
                cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                # Should raise error since rank 1 is given `device_id=0` when
                # the model is on cuda:1
                fsdp_kwargs={"device_id": 0},
            )

    @skip_if_lt_x_gpu(2)
    def test_multi_device_not_supported(self):
        """Tests that wrapping a multi-device module (i.e. with submodules on
        both GPU and CPU) with FSDP raises an error."""
        class MultiDeviceModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Linear(1, 1).cuda()
                self.b = nn.Linear(1, 1)

        with self.assertRaisesRegex(
            RuntimeError, "FSDP only supports single device modules"
        ):
            FSDP(MultiDeviceModule())

    @skip_if_lt_x_gpu(2)
    def test_no_params(self):
        """
        Test that device_id and cpu init work if module has no params
        (they are effective noops, but ensure FSDP does not assume module
        has parameters during init)
        """
        # Test CPU
        no_params = nn.ReLU()
        module = FSDP(no_params)
        # Test CUDA
        no_params = nn.ReLU().cuda()
        module = FSDP(no_params)
        # Test CPU + device_id
        no_params = nn.ReLU()
        module = FSDP(no_params, device_id=torch.cuda.current_device())
        # For modules with no params, wrong device_id will raise error about
        # inconsistency between compute_device and device_id, since compute_device
        # is computed as torch.cuda.current_device when there are no params.
        no_params = nn.ReLU().cuda()
        context = (
            self.assertRaisesRegex(
                ValueError,
                f"Inconsistent.*cuda:{self.rank} vs cuda:0"
            )
        ) if self.rank != 0 else suppress()
        with context:
            module = FSDP(no_params, device_id=0)

    @skip_if_lt_x_gpu(2)
    def test_fsdp_cpu_init_stays_on_cpu(self):
        """Tests that passing a CPU module to FSDP preserves that the wrapped
        module is on CPU after FSDP initialization, albeit after logging a
        warning, and that FSDP moves CPU input to GPU before the forward."""
        torch.cuda.set_device(self.rank)
        regex = "Module is put on CPU"
        context = self.assertWarnsRegex(
            expected_warning=UserWarning, expected_regex=regex
        )
        with context:
            nested_wrapped_module = NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_NEVER,
            )
            fsdp_model = FSDP(nested_wrapped_module, self.process_group)
        devices = {p.device for p in fsdp_model.parameters()}
        self.assertEqual(1, len(devices))
        self.assertEqual(torch.device("cpu"), devices.pop())
        fsdp_model = fsdp_model.cuda()
        # Ensure fwd + backward can be performed after moving to CUDA.
        # CPU input also tests that input is correctly moved to appropriate
        # CUDA device.
        inp = fsdp_model.module.get_input(device=torch.device("cpu"))
        fsdp_model(*inp).sum().backward()

    @skip_if_lt_x_gpu(2)
    def test_cpu_init_with_sync_module_states(self):
        """Tests that passing ``sync_module_states=True`` raises an error for
        a CPU module since the synchronization requires GPU communication,
        while additionally passing ``device_id`` does not raise an error."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
        )
        with self.assertRaisesRegex(
            ValueError,
            "Module has CPU parameters, but sync_module_states=True is specified."
        ):
            FSDP(nested_wrapped_module, self.process_group, sync_module_states=True)

        # Specifying device_id with sync_module_states=True works.
        FSDP(
            nested_wrapped_module,
            self.process_group,
            device_id=torch.cuda.current_device(),
            sync_module_states=True,
        )

    @skip_if_lt_x_gpu(2)
    def test_fsdp_same_model_across_ranks(self):
        """
        FSDP broadcasts model from rank 0 to ensure it starts off with the same
        values.
        """
        class MyModel(nn.Module):
            def __init__(self, rank):
                super().__init__()
                # Seed via rank to make model different across ranks
                torch.manual_seed(rank)
                torch.cuda.manual_seed(rank)
                self.lin = nn.Linear(10, 10, bias=False)
                self.register_buffer("buffer", torch.ones(1) * rank)

        m = MyModel(self.rank).cuda()
        _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)

        # sync_module_states also works with CPU module with device_id passed in
        m = MyModel(self.rank)
        _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)


# Turn the ``@parametrize``-decorated methods of TestFSDPMisc into runnable
# tests (helper from torch.testing._internal.common_utils).
instantiate_parametrized_tests(TestFSDPMisc)

# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()