File: test_fsdp_unshard_params.py

# Owner(s): ["oncall: distributed"]
import contextlib
import itertools
import math
import sys
from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
from torch import distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp._common_utils import clean_tensor_name
from torch.distributed.fsdp._flat_param import FlatParameter
from torch.distributed.fsdp.fully_sharded_data_parallel import FLAT_PARAM
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    DEVICEInitMode,
    FSDPInitMode,
    FSDPTest,
    get_devtype,
    NestedWrappedModule,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


device_type = torch.device(get_devtype())

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
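

# The tests below all exercise `FSDP.summon_full_params()`. As a minimal,
# illustrative sketch (not invoked by any test), the canonical usage pattern
# looks like the helper below; it assumes `model` is already FSDP-wrapped
# under an initialized process group.
def _summon_full_params_usage_sketch(model: FSDP) -> None:
    # Inside the context, each rank sees the full, unsharded parameters. With
    # `writeback=True` (the default), in-place edits made here persist back to
    # the sharded flat parameters once the context exits.
    with FSDP.summon_full_params(model, writeback=True), torch.no_grad():
        for param in model.parameters():
            param.add_(1.0)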


class TestUnshardParamsBase(FSDPTest):
    """
    This contains the methods common to both the sharded and non-sharded cases.
    """

    def _test_unshard_params_writeback(
        self,
        writeback: bool,
        check_outer: bool,
        **fsdp_kwargs: Dict[str, Any],
    ):
        model = nn.Sequential(
            nn.Linear(5, 5, bias=False, device=device_type.type),
            nn.Linear(5, 3, bias=False, device=device_type.type),
        )
        model[0] = FSDP(model[0], **fsdp_kwargs)
        model = FSDP(model, **fsdp_kwargs)
        uses_sharded_strategy = model.sharding_strategy != ShardingStrategy.NO_SHARD
        offloading_params = model.cpu_offload.offload_params

        # Assumes depth-first `.parameters()`
        outer_param: Union[FlatParameter, nn.Parameter] = next(model.parameters())
        inner_param: Union[FlatParameter, nn.Parameter] = next(model[0].parameters())
        param_to_check = outer_param if check_outer else inner_param

        # Write a known value to all elements of the *sharded* parameter or
        # `FlatParameter` to check
        with torch.no_grad():
            param_to_check.zero_()
            param_to_check += self.rank + 2
        # Zero the *unsharded* parameters
        with FSDP.summon_full_params(model, writeback=writeback), torch.no_grad():
            for param in model.parameters():
                param.zero_()

        # Check the first element of the sharded parameter (or `FlatParameter`)
        # to see whether the zeroing from inside the context persisted
        param_elem_to_check = param_to_check[0]
        if param_elem_to_check.numel() > 1:
            # For `use_orig_params=True` and `NO_SHARD`, the parameter
            # preserves its original 2D shape, so we must index once more to
            # get a single element
            param_elem_to_check = param_elem_to_check[0]
        if writeback or (not uses_sharded_strategy and not offloading_params):
            # When FSDP does not use a sharded strategy and is not offloading
            # parameters to CPU, it directly exposes the tensor storage that
            # serves as the unsharded source of truth, so the write is always
            # reflected regardless of `writeback`.
            self.assertEqual(param_elem_to_check, 0)
        else:
            self.assertEqual(param_elem_to_check, self.rank + 2)
        if offloading_params:
            cpu_device = torch.device("cpu")
            for param in model.parameters():
                self.assertEqual(param.device, cpu_device)

    def _get_test_unshard_params_writeback_config(self) -> Dict[str, List[Any]]:
        return {
            "writeback": [True, False],
            "check_outer": [True, False],
            "mixed_precision": [MixedPrecision(param_dtype=torch.float16), None],
            "cpu_offload": [
                CPUOffload(offload_params=False),
                CPUOffload(offload_params=True),
            ],
            "use_orig_params": [True, False],
        }

    def _test_unshard_params_param_data(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        cpu_offload: CPUOffload,
        mixed_precision: Optional[MixedPrecision],
        use_orig_params: bool,
    ):
        local_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            DEVICEInitMode.DEVICE_BEFORE,
            fsdp_kwargs={"device_id": device_type.type},
            deterministic=True,
        )
        # Apply FSDP only to submodules so that the overall root module is not
        # FSDP-wrapped while there are multiple FSDP root submodules (verified
        # at the end of this test)
        fsdp_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            DEVICEInitMode.DEVICE_BEFORE,
            fsdp_kwargs={
                "cpu_offload": cpu_offload,
                "mixed_precision": mixed_precision,
                "use_orig_params": use_orig_params,
            },
            deterministic=True,
        )
        self.assertFalse(isinstance(fsdp_model, FSDP))
        # Hard code the following names because getting them is non-trivial
        non_fsdp_managed_param_names = {
            "module.0.weight",
            "module.0.bias",
            "module.3.weight",
            "module.3.bias",
        }

        with FSDP.summon_full_params(
            fsdp_model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            if not rank0_only or self.rank == 0:
                for p1, (n2, p2) in zip(
                    local_model.parameters(), fsdp_model.named_parameters()
                ):
                    self.assertEqual(p1.shape, p2.shape)
                    if (
                        offload_to_cpu
                        and clean_tensor_name(n2) not in non_fsdp_managed_param_names
                    ):
                        self.assertEqual(torch.device("cpu"), p2.device)
                    else:
                        self.assertEqual(p1.device, p2.device)
                    self.assertEqual(
                        p1.dtype, p2.dtype
                    )  # even if FSDP uses mixed precision
                    self.assertEqual(p1, p2)
                    self.assertTrue(isinstance(p2, nn.Parameter))
            else:
                # Check that each `FlatParameter` has the sharded size as a
                # proxy for it being resharded
                for handle in traversal_utils._get_fsdp_handles(fsdp_model):
                    if handle.uses_sharded_strategy:
                        self.assertEqual(
                            handle.flat_param.shape, handle.flat_param._sharded_size
                        )
                    else:
                        self.assertEqual(
                            handle.flat_param.shape,
                            handle.flat_param._unpadded_unsharded_size,
                        )

        # Verify that there are multiple FSDP roots after lazy initialization
        num_fsdp_roots = 0
        for fsdp_state in traversal_utils._get_fsdp_states(fsdp_model):
            num_fsdp_roots += fsdp_state._is_root
        self.assertGreater(num_fsdp_roots, 1)

    def _get_test_unshard_params_param_data_config(self) -> Dict[str, List[Any]]:
        return {
            "rank0_only": [False, True],
            "offload_to_cpu": [False, True],
            "cpu_offload": [
                CPUOffload(offload_params=False),
                CPUOffload(offload_params=True),
            ],
            "mixed_precision": [MixedPrecision(param_dtype=torch.float16), None],
            "use_orig_params": [True, False],
        }


class TestUnshardParams(TestUnshardParamsBase):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_writeback(self):
        """Tests the ``writeback`` argument (using default for all others)."""
        self.run_subtests(
            self._get_test_unshard_params_writeback_config(),
            self._test_unshard_params_writeback,
        )

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_param_data(self):
        """
        Tests that parameters are exposed correctly for ``recurse=True`` and
        all other argument configs for a non-FSDP root module.
        """
        self.run_subtests(
            self._get_test_unshard_params_param_data_config(),
            self._test_unshard_params_param_data,
        )

    @skip_if_lt_x_gpu(2)
    def test_unshard_singleton_param_writeback(self):
        """
        Tests ``writeback=True`` for a singleton parameter, which includes
        testing that writing to padding does not persist.
        NOTE: This method depends on FSDP internals.
        """
        model = FSDP(nn.Linear(1, 1, bias=False, device=device_type.type))
        flat_param = model._handle.flat_param
        self.assertEqual(1, flat_param.numel())
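        # With world_size=2, the unsharded `FlatParameter` has the single real
        # element plus one padding element, so rank 0's 1-element shard holds
        # the real element while rank 1's shard holds only padding.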
        # Write a known value to the *sharded* `FlatParameter`
        with torch.no_grad():
            # For nonzero ranks, this write is to padding
            flat_param[0] = self.rank + 2
        with FSDP.summon_full_params(model, writeback=True):
            self.assertEqual(1, flat_param.numel())
            with torch.no_grad():
                flat_param.zero_()
        # NOTE: This checks that the zeroing inside the context was not written
        # back to the padding element, which is *not* strictly required for
        # correctness.
        if self.rank == 0:  # did not write to padding
            self.assertEqual(0, flat_param[0])
        else:  # wrote to padding
            self.assertEqual(self.rank + 2, flat_param[0])

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_respects_reshard(self):
        """
        Tests that unsharding parameters respects the expected reshard behavior
        between forward and backward as well as after backward.

        For mixed precision, we should *not* respect the reshard behavior
        because the ``summon_full_params()`` forces full precision, which uses
        a different all-gather tensor than the one already in memory and will
        not persist any modifications correctly.
        """
        self.run_subtests(
            {
                "rank0_only": [False, True],
                "offload_to_cpu": [False, True],
                "mixed_precision": [MixedPrecision(param_dtype=torch.float16), None],
                "use_orig_params": [False, True],
            },
            self._test_unshard_params_respects_reshard,
        )

    def _test_unshard_params_respects_reshard(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        mixed_precision: Optional[MixedPrecision],
        use_orig_params: bool,
    ):
        """NOTE: This method depends on FSDP internals."""
        fsdp_kwargs = {
            "mixed_precision": mixed_precision,
            "use_orig_params": use_orig_params,
        }
        model = FSDP(
            nn.Sequential(
                FSDP(
                    nn.Linear(5, 5, bias=False, device=device_type.type), **fsdp_kwargs
                ),
                nn.Linear(5, 3, bias=False, device=device_type.type),
            ),
            **fsdp_kwargs,
        )
        outer_flat_param = model._handle.flat_param
        inner_flat_param = model.module[0]._handle.flat_param
        # NOTE: This assumes uniform sharding with padding across ranks.
        expected_outer_flat_param_unsharded_numel = (
            outer_flat_param.numel() * self.world_size
        )
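        # For example, with world_size=2: the outer `FlatParameter` holds only
        # the second `nn.Linear(5, 3)` (15 elements), so each rank's shard has
        # 8 elements after padding and the padded unsharded numel is 16.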

        def _get_unsharded_storage_size(flat_param: FlatParameter):
            return flat_param._full_param_padded.storage().size()

        # Validate the expected behavior: the root does not reshard after
        # forward; the non-root reshards after forward; and both reshard after
        # backward
        output = model(torch.zeros(5, device=device_type.type))
        self.assertEqual(
            expected_outer_flat_param_unsharded_numel,
            _get_unsharded_storage_size(outer_flat_param),
        )
        self.assertEqual(0, _get_unsharded_storage_size(inner_flat_param))
        output.sum().backward()
        self.assertEqual(0, _get_unsharded_storage_size(outer_flat_param))
        self.assertEqual(0, _get_unsharded_storage_size(inner_flat_param))

        # Check that, with parameter unsharding both between forward and
        # backward and after backward, the reshard behavior matches the
        # expectations above
        output = model(torch.zeros(5, device=device_type.type))
        with FSDP.summon_full_params(
            model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            pass
        if mixed_precision is not None:
            # After forcing full precision, we must invalidate the existing
            # unsharded low-precision flat parameter since it will not persist
            # changes from the `summon_full_params()` context, so we cannot
            # respect the reshard behavior
            expected_outer_flat_param_unsharded_numel = 0
        self.assertEqual(
            expected_outer_flat_param_unsharded_numel,
            _get_unsharded_storage_size(outer_flat_param),
        )
        self.assertEqual(0, _get_unsharded_storage_size(inner_flat_param))
        output.sum().backward()
        with FSDP.summon_full_params(
            model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, _get_unsharded_storage_size(outer_flat_param))
        self.assertEqual(0, _get_unsharded_storage_size(inner_flat_param))

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_recurse(self):
        """Tests the ``recurse`` argument (using default for all others)."""
        self.run_subtests(
            {
                "recurse": [False, True],
                "unshard_outer": [False, True],
                "mixed_precision": [MixedPrecision(param_dtype=torch.float16), None],
                "use_orig_params": [False, True],
            },
            self._test_unshard_params_recurse,
        )

    def _test_unshard_params_recurse(
        self,
        recurse: bool,
        unshard_outer: bool,
        mixed_precision: Optional[MixedPrecision],
        use_orig_params: bool,
    ):
        """NOTE: This method depends on FSDP internals."""
        fsdp_kwargs = {
            "mixed_precision": mixed_precision,
            "use_orig_params": use_orig_params,
        }
        model = FSDP(
            nn.Sequential(
                FSDP(
                    nn.Linear(5, 5, bias=False, device=device_type.type), **fsdp_kwargs
                ),
                nn.Linear(5, 3, bias=False, device=device_type.type),
            ),
            **fsdp_kwargs,
        )
        # Hard code the numel values based on the model
        unsharded_inner_numel = 5 * 5
        unsharded_outer_numel = 5 * 3
        if use_orig_params:
            # Account for unsharded padding: since each `FlatParameter` only
            # has one original parameter, we only need to pad for divisibility
            # by world size and not address alignment
            if unsharded_inner_numel % self.world_size:
                unsharded_inner_numel += self.world_size - (
                    unsharded_inner_numel % self.world_size
                )
            if unsharded_outer_numel % self.world_size:
                unsharded_outer_numel += self.world_size - (
                    unsharded_outer_numel % self.world_size
                )
        # Round up the sharded numel to account for padding
        sharded_inner_numel = int(math.ceil(unsharded_inner_numel / self.world_size))
        sharded_outer_numel = int(math.ceil(unsharded_outer_numel / self.world_size))
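        # Worked example for world_size=2: the inner 5x5 weight has 25
        # elements, padded to 26 when `use_orig_params=True`, giving a
        # 13-element shard per rank (also ceil(25 / 2) = 13 without padding);
        # the outer 5x3 weight has 15 elements, padded to 16, giving an
        # 8-element shard per rank (also ceil(15 / 2) = 8).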
        inner_flat_param = model.module[0]._handle.flat_param
        outer_flat_param = model._handle.flat_param
        self.assertEqual(sharded_inner_numel, inner_flat_param.numel())
        self.assertEqual(sharded_outer_numel, outer_flat_param.numel())
        expected_outer_numel = (
            unsharded_outer_numel if unshard_outer else sharded_outer_numel
        )
        expected_inner_numel = (
            unsharded_inner_numel
            if recurse or not unshard_outer
            else sharded_inner_numel
        )
        module_to_unshard = model if unshard_outer else model[0]
        with FSDP.summon_full_params(module_to_unshard, recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_flat_param.numel())
            self.assertEqual(expected_inner_numel, inner_flat_param.numel())

    @skip_if_lt_x_gpu(2)
    def test_named_parameters_and_buffers(self):
        """
        Tests that ``named_parameters()`` and ``named_buffers()`` for a
        top-level FSDP-wrapped model match their behavior for the equivalent
        non-wrapped module.
        """
        self.run_subtests(
            {"prefix": ["", "test_prefix"], "recurse": [False, True]},
            self._test_named_parameters_and_buffers,
        )

    def _test_named_parameters_and_buffers(self, prefix: str, recurse: bool):
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            DEVICEInitMode.DEVICE_BEFORE,
            deterministic=True,
        )
        model.buffer = nn.Buffer(torch.ones(1))
        # Wrap the top-level with FSDP since `named_parameters()` and
        # `named_buffers` will contain FSDP prefixes if called on a non-FSDP
        # root module
        fsdp_model = FSDP(
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                DEVICEInitMode.DEVICE_BEFORE,
                deterministic=True,
            ),
            self.process_group,
        )
        fsdp_model.buffer = nn.Buffer(torch.ones(1))
        with FSDP.summon_full_params(fsdp_model):
            for call in ["named_parameters", "named_buffers"]:
                for (n1, p1), (n2, p2) in itertools.zip_longest(
                    getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                    getattr(model, call)(prefix=prefix, recurse=recurse),
                ):
                    self.assertEqual(n1, n2)
                    self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_with_grads_core(self):
        """
        Tests the core usage of ``with_grads=True`` by comparing against DDP as
        the unsharded equivalent.
        """
        self.run_subtests(
            {
                "writeback": [False, True],
                "offload_to_cpu": [False, True],
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ],
                "use_orig_params": [True],
            },
            self._test_with_grads_core,
        )

    def _test_with_grads_core(
        self,
        writeback: bool,
        offload_to_cpu: bool,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
    ):
        def _check_grads(
            ddp_model: DDP,
            fsdp_model: FSDP,
            old_fsdp_grads: Optional[List[torch.Tensor]],
        ):
            """
            Checks that writes to the FSDP parameters' gradients persist or do
            not persist depending on ``writeback`` and the sharding strategy.
            The DDP model is used for checking gradient parity to ensure that
            FSDP all-gathers the correct gradient values.
            """
            WRITEBACK_FACTOR = 2
            with FSDP.summon_full_params(
                fsdp_model,
                writeback=writeback,
                offload_to_cpu=offload_to_cpu,
                with_grads=True,
            ):
                for (n1, p1), (n2, p2) in zip(
                    ddp_model.module.named_parameters(),
                    fsdp_model.named_parameters(),
                ):
                    self.assertEqual(n1, clean_tensor_name(n2))
                    assert p1.grad is not None
                    torch.testing.assert_close(p1.grad, p2.grad)
                    # Ensure that the tensor is not all zeros, which would
                    # mean that the multiplication is vacuous
                    assert torch.count_nonzero(p2.grad) > 0
                    p2.grad *= WRITEBACK_FACTOR
            new_fsdp_grads = [
                param.grad
                for param in fsdp_model.parameters()
                if param.grad is not None
            ]
            writeback_persists = writeback or (
                sharding_strategy == ShardingStrategy.NO_SHARD and not offload_to_cpu
            )
            for old_grad, new_grad in zip(old_fsdp_grads, new_fsdp_grads):
                if writeback_persists:
                    torch.testing.assert_close(old_grad * WRITEBACK_FACTOR, new_grad)
                else:
                    torch.testing.assert_close(old_grad, new_grad)
            if writeback_persists:
                # Modify the DDP gradients in the same way for parity
                for param in ddp_model.parameters():
                    param.grad *= WRITEBACK_FACTOR

        def _get_error_context(is_supported: bool):
            return (
                contextlib.nullcontext()
                if is_supported
                else self.assertRaises(NotImplementedError)
            )  # some configs are not implemented yet

        def _get_fsdp_grads(fsdp_model: FSDP, is_supported: bool):
            if is_supported:
                return [
                    param.grad.clone()
                    for param in fsdp_model.parameters()
                    if param.grad is not None
                ]
            return None  # unused

        is_supported = use_orig_params and not offload_to_cpu
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            DEVICEInitMode.DEVICE_BEFORE,
            deterministic=True,
        )
        ddp_model = DDP(model, device_ids=[device_type])
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            DEVICEInitMode.DEVICE_BEFORE,
            deterministic=True,
            fsdp_kwargs={
                "use_orig_params": use_orig_params,
                "sharding_strategy": sharding_strategy,
                "device_id": device_type.type,
            },
        )
        with FSDP.summon_full_params(fsdp_model):
            for p1, p2 in zip(ddp_model.module.parameters(), fsdp_model.parameters()):
                assert torch.all(torch.isclose(p1, p2))

        # Check calling after backward
        inp = fsdp_model.get_input(torch.device(device_type))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        ddp_out.sum().backward()
        fsdp_out.sum().backward()
        old_fsdp_grads = _get_fsdp_grads(fsdp_model, is_supported)
        with _get_error_context(is_supported):
            _check_grads(ddp_model, fsdp_model, old_fsdp_grads)

        # Check calling between forward and backward
        inp = fsdp_model.get_input(torch.device(device_type))
        ddp_out = ddp_model(*inp)
        fsdp_out = fsdp_model(*inp)
        old_fsdp_grads = _get_fsdp_grads(fsdp_model, is_supported)
        with _get_error_context(is_supported):
            _check_grads(ddp_model, fsdp_model, old_fsdp_grads)

    @skip_if_lt_x_gpu(2)
    def test_with_grads_none_grads(self):
        """
        Tests that if the ``FlatParameter`` has a ``None`` gradient on all
        ranks, then each original parameter sees a ``None`` gradient as well.
        """
        self.run_subtests(
            {
                "sharding_strategy": [
                    ShardingStrategy.FULL_SHARD,
                    ShardingStrategy.SHARD_GRAD_OP,
                    ShardingStrategy.NO_SHARD,
                ]
            },
            self._test_with_grads_none_grads,
        )

    def _test_with_grads_none_grads(self, sharding_strategy: ShardingStrategy):
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            DEVICEInitMode.DEVICE_BEFORE,
            deterministic=True,
            fsdp_kwargs={
                "use_orig_params": True,
                "sharding_strategy": sharding_strategy,
                "device_id": device_type.type,
            },
        )
        for fsdp_module in FSDP.fsdp_modules(fsdp_model):
            if fsdp_module._handle:
                assert fsdp_module._handle.flat_param.grad is None
        with FSDP.summon_full_params(fsdp_model, with_grads=True):
            for param in fsdp_model.parameters():
                self.assertTrue(param.grad is None)

    @skip_if_lt_x_gpu(2)
    def test_unshard_submodule(self):
        model = nn.Sequential(
            nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
            nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16)),
        ).to(device_type.type)
        model = FSDP(model, auto_wrap_policy=ModuleWrapPolicy((nn.Sequential,)))
        with FSDP.summon_full_params(model[0]):
            # Check that the summoned module does not have its flat parameter
            for param_name, param in model[0].named_parameters():
                self.assertFalse(FLAT_PARAM in param_name)
            self.assertGreater(len(list(model[0].parameters())), 1)


class TestUnshardParamsNoShard(TestUnshardParamsBase):
    @property
    def world_size(self) -> int:
        return 1

    @skip_if_lt_x_gpu(1)
    def test_unshard_params_writeback_no_shard(self):
        """Tests the ``writeback`` argument (using default for all others)."""
        self.run_subtests(
            self._get_test_unshard_params_writeback_config(),
            self._test_unshard_params_writeback,
        )

    @skip_if_lt_x_gpu(1)
    def test_unshard_params_param_data_no_shard(self):
        """
        Tests that parameters are exposed correctly for ``recurse=True`` and
        all other argument configs for a non-FSDP root module.
        """
        config = self._get_test_unshard_params_param_data_config()
        # TODO: `offload_to_cpu=True` with `NO_SHARD` is not supported yet. See
        # `test_offload_to_cpu_no_shard_raises()`.
        config["offload_to_cpu"] = [False]
        self.run_subtests(
            config,
            self._test_unshard_params_param_data,
        )


class TestUnshardParamsErrors(TestUnshardParamsBase):
    @property
    def world_size(self) -> int:
        return 2

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_from_forward_raises(self):
        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = nn.Parameter(torch.zeros(5))

            def forward(self, fsdp_module):
                with fsdp_module.summon_full_params(fsdp_module):
                    pass

        model = FSDP(MyModule()).to(device_type.type)
        with self.assertRaisesRegex(
            AssertionError, "Cannot manually unshard parameters during forward/backward"
        ):
            model(model)

    @skip_if_lt_x_gpu(2)
    def test_unshard_params_from_backward_raises(self):
        model = FSDP(nn.Linear(2, 1, device=device_type.type))
        output = model(torch.ones(2, device=device_type.type))

        def invalid_backward_hook(*args, **kwargs):
            with FSDP.summon_full_params(model):
                pass

        self.assertTrue(output.requires_grad)
        output.register_hook(invalid_backward_hook)
        with self.assertRaisesRegex(
            AssertionError, "Cannot manually unshard parameters during forward/backward"
        ):
            output.backward()

    @skip_if_lt_x_gpu(2)
    def test_rank0_only_with_writeback_raises(self):
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            DEVICEInitMode.DEVICE_BEFORE,
        )
        with self.assertRaisesRegex(NotImplementedError, "is not supported"):
            with FSDP.summon_full_params(
                nested_wrapped_module, rank0_only=True, writeback=True
            ):
                pass

    @skip_if_lt_x_gpu(2)
    def test_offload_to_cpu_no_shard_raises(self):
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            DEVICEInitMode.DEVICE_BEFORE,
            {"sharding_strategy": ShardingStrategy.NO_SHARD},
        )
        with self.assertRaisesRegex(NotImplementedError, "is not supported"):
            with FSDP.summon_full_params(
                nested_wrapped_module, offload_to_cpu=True
            ):
                pass


if __name__ == "__main__":
    run_tests()