File: test_fsdp_grad_acc.py

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (292 lines) | stat: -rw-r--r-- 10,968 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# Owner(s): ["oncall: distributed"]

import contextlib
import itertools
import sys
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import distributed as dist
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    BackwardPrefetch,
    ShardingStrategy,
)
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    CUDAInitMode,
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

# Exit successfully (status 0) instead of failing when this environment cannot
# run the tests below: either this torch build has no distributed support, or
# it is a dev-asan build, where torch + multiprocessing spawn are known to
# misbehave. Because ``sys.exit`` raises, at most one branch ever runs.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
elif TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


@dataclass
class _GradAccConfig:
    """
    This configures how gradients are accumulated in :meth:`_test_grad_acc`.
    Each instance of this class represents ``num_iters``-many consecutive
    iterations, where the ``no_sync()`` context manager is used or not as given
    by ``use_no_sync``.

    Attributes:
        use_no_sync (bool): Indicates whether to use the ``no_sync()`` context
            manager as the way to accumulate gradients.
        num_iters (int): Number of iterations to accumulate gradients.
    """
    use_no_sync: bool
    num_iters: int

    def __repr__(self) -> str:
        # Override to remove any spaces in the string to appease the internal
        # build's test name parser
        return (
            f"(use_no_sync={self.use_no_sync},"
            f"num_iters={self.num_iters})"
        )


@dataclass
class _GradAccConfigs:
    """
    Thin wrapper around a :class:`list` of :class:`_GradAccConfig` objects.
    It exists only to supply a :meth:`__repr__` that contains no spaces.
    """
    configs: List[_GradAccConfig]

    def __repr__(self) -> str:
        # The internal build's test-name parser cannot cope with spaces, so
        # the per-config reprs are joined with a bare comma inside brackets.
        rendered = [repr(cfg) for cfg in self.configs]
        return "[{}]".format(",".join(rendered))


class TestGradAcc(FSDPTest):
    """Tests ``FullyShardedDataParallel``'s gradient accumulation via both its
    ``no_sync()`` context manager and without the context manager."""

    def _test_grad_acc(
        self,
        batch_dim: int,
        configs: List[_GradAccConfig],
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
        sharding_strategy: ShardingStrategy,
    ):
        """
        Tests gradient accumulation by comparing a run that trains sequentially
        through some batches while accumulating gradients with a run that
        trains on the concatenation of those batches in a single iteration.

        The last iteration always synchronizes gradients regardless of what is
        specified by the last element of ``configs``.

        If ``cpu_offload.offload_params`` is set and any element of
        ``configs`` accumulates outside ``no_sync()``, this returns
        immediately without checking anything, since that combination is not
        currently supported.

        Arguments:
            batch_dim (int): Batch dimension in the input tensor to be passed
                into the model for the forward pass.
            configs (List[_GradAccConfig]): :class:`list` of configurations
                specifying how gradients are accumulated; for example, a list
                corresponding to [(False, 2), (True, 2), (False, 2)] indicates
                to accumulate over 2 + 2 + 2 = 6 total iterations, where the
                first two do not use ``no_sync()``, the middle two do use
                ``no_sync()``, and the final two again do not use
                ``no_sync()``.
            cpu_offload (CPUOffload): Configures CPU offloading.
            backward_prefetch (Optional[BackwardPrefetch]): Specifies at which
                point to prefetch the next layer's full parameters during the
                backward pass, if at all.
            sharding_strategy (ShardingStrategy): Sharding strategy forwarded
                to the FSDP constructor (via ``fsdp_kwargs``).
        """
        # Gradient accumulation outside `no_sync()` is not currently compatible
        # with CPU offloading
        if (
            cpu_offload.offload_params
            and any(not config.use_no_sync for config in configs)
        ):
            return
        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        try:
            # Disable TF32 to prevent floating point drift
            torch.backends.cuda.matmul.allow_tf32 = False

            # Initialize the FSDP model and optimizer
            fsdp_kwargs = {
                "cpu_offload": cpu_offload,
                "backward_prefetch": backward_prefetch,
                "sharding_strategy": sharding_strategy,
            }
            fsdp_model: FSDP = TransformerWithSharedParams.init(
                self.process_group,
                FSDPInitMode.RECURSIVE,
                CUDAInitMode.CUDA_AFTER,
                fsdp_kwargs,
                deterministic=True,
                add_bn=False,  # disable BN since the test uses varying batch sizes
            )
            device = torch.device("cuda")
            optim = torch.optim.SGD(
                fsdp_model.parameters(), lr=0.01, momentum=0.9,
            )

            # Generate the sequence of batches, each containing the same data
            # but permuted
            def permute_tensor(x: torch.Tensor):
                return x.view(-1)[torch.randperm(x.numel())].view_as(x)

            batch: Tuple[torch.Tensor, ...] = (
                fsdp_model.module.get_input(device)
            )
            batches: List[Tuple[torch.Tensor, ...]] = [batch]
            num_iters_to_acc = sum(config.num_iters for config in configs)
            for _ in range(num_iters_to_acc - 1):
                batches.append(tuple(permute_tensor(t) for t in batch))
            for batch1, batch2 in itertools.combinations(batches, r=2):
                for t1, t2 in zip(batch1, batch2):
                    assert not torch.all(t1 == t2), (
                        "Check the test to make sure that batches are distinct"
                    )

            # Concatenate the batches along the given batch dimension
            concat_batch: Tuple[torch.Tensor, ...] = tuple(
                torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
            )

            # Establish reference gradients using the concatenated batch
            fsdp_model.zero_grad()
            output = fsdp_model(*concat_batch)
            ref_loss = fsdp_model.module.get_loss(concat_batch, output)
            ref_loss.backward()
            ref_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compute and accumulate the gradients
            fsdp_model.zero_grad()
            losses = []
            batch_idx = 0
            for config in configs:
                # `nullcontext()` is the documented no-op context manager; it
                # is used when this stretch should accumulate with syncing.
                sync_context = (
                    fsdp_model.no_sync()
                    if config.use_no_sync
                    else contextlib.nullcontext()
                )
                with sync_context:
                    for _ in range(config.num_iters):
                        if batch_idx == num_iters_to_acc - 1:
                            break  # always sync on the last iteration
                        batch = batches[batch_idx]
                        batch_idx += 1
                        output = fsdp_model(*batch)
                        loss = fsdp_model.module.get_loss(batch, output)
                        loss.backward()
                        losses.append(loss)
            # The final iteration runs outside any `no_sync()` context so that
            # the accumulated gradients are synchronized.
            output = fsdp_model(*batches[-1])
            loss = fsdp_model.module.get_loss(batches[-1], output)
            loss.backward()
            losses.append(loss)
            acc_loss = sum(losses)
            acc_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compare the losses and gradients
            torch.testing.assert_close(ref_loss, acc_loss)
            self.assertEqual(len(ref_grads), len(acc_grads))
            for ref_grad, acc_grad in zip(ref_grads, acc_grads):
                self.assertEqual(ref_grad.device, acc_grad.device)
                self.assertEqual(ref_grad.size(), acc_grad.size())
                self.assertEqual(ref_grad.dtype, acc_grad.dtype)
                torch.testing.assert_close(ref_grad, acc_grad)

            # Check that the optimizer step does not error
            optim.step()
        finally:
            # Restore the global TF32 setting even if the test fails
            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32

    def _get_subtest_config(self) -> Dict[str, List[Any]]:
        """Returns a subtest configuration that subtests prefetching."""
        return {
            "backward_prefetch": [
                None,
                BackwardPrefetch.BACKWARD_PRE,
                BackwardPrefetch.BACKWARD_POST,
            ]
        }

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "configs",
        [
            _GradAccConfigs([
                _GradAccConfig(use_no_sync=True, num_iters=3),
                _GradAccConfig(use_no_sync=False, num_iters=3),
                _GradAccConfig(use_no_sync=True, num_iters=3),
            ]),
            _GradAccConfigs([
                _GradAccConfig(use_no_sync=False, num_iters=3),
                _GradAccConfig(use_no_sync=True, num_iters=3),
                _GradAccConfig(use_no_sync=False, num_iters=3),
            ]),
        ]
    )
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize(
        "sharding_strategy",
        [
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.NO_SHARD,
        ]
    )
    def test_grad_acc(
        self,
        configs: _GradAccConfigs,
        cpu_offload: CPUOffload,
        sharding_strategy: ShardingStrategy,
    ):
        """
        Tests gradient accumulation.

        This exercises gradient accumulation inside and outside the
        ``no_sync()`` context manager, in particular by interleaving the two.
        It tests both interleaving starting with (and ending with, resp.)
        inside versus outside ``no_sync()`` to ensure that initial conditions
        (and final conditions, resp.) do not affect the correctness. This test
        also checks for compatibility with the CPU offload and backward
        prefetch options.

        NOTE: Gradient accumulation without using the ``no_sync()`` context
        manager is not currently compatible with CPU offloading, so those tests
        are vacuous.
        """
        # Each backward-prefetch option from `_get_subtest_config()` is run
        # as a subtest with the remaining arguments fixed.
        self.run_subtests(
            self._get_subtest_config(),
            self._test_grad_acc,
            batch_dim=1,
            configs=configs.configs,
            cpu_offload=cpu_offload,
            sharding_strategy=sharding_strategy,
        )


# Expand the ``@parametrize``-decorated test methods on ``TestGradAcc`` into
# concrete tests (presumably one per combination of parametrized values;
# confirm against ``torch.testing._internal.common_utils``). This call takes
# the class object, so it must come after the class definition.
instantiate_parametrized_tests(TestGradAcc)

# Entry point when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()