File: test_fsdp_sharded_grad_scaler.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (350 lines) | stat: -rw-r--r-- 13,608 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
# Owner(s): ["oncall: distributed"]

import copy
import functools
import itertools
import sys
import unittest
from typing import List, Optional

import torch
from torch import distributed as dist
from torch.cuda.amp.common import amp_definitely_not_available
from torch.distributed.fsdp import CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    DEVICEInitMode,
    DummyProcessGroup,
    FSDPInitMode,
    FSDPTest,
    NestedWrappedModule,
    NonUniformReqGradNWM,
    subtest_name,
    TransformerWithSharedParams,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)


# Leave the process early (exit status 0) before any test is defined when the
# environment cannot run these tests: either torch.distributed is missing, or
# this is a dev-asan build.  ``sys.exit`` raises, so at most one branch runs.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)
elif TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


params = "cpu_offload,sharding_strategy,mixed_precision,use_orig_params"
cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
sharding_strategy_config = [ShardingStrategy.SHARD_GRAD_OP, None]
mixed_precision = ["enable_mixed_precision", None]
use_orig_params = ["enable_use_orig_params", None]

configs = list(
    itertools.product(
        cpu_offload_config, sharding_strategy_config, mixed_precision, use_orig_params
    )
)
test_name_mapping = {
    str(CPUOffload(offload_params=True)): "offload_true",
    str(CPUOffload(offload_params=False)): "offload_false",
    str(ShardingStrategy.SHARD_GRAD_OP): "shard_grad_op",
    "enable_mixed_precision": "mixed_precision",
    "enable_use_orig_params": "use_orig_params",
}

subtest_name = functools.partial(subtest_name, test_name_mapping)


class TestShardGradScaler(TestCase):
    """Single-process unit tests for ``ShardedGradScaler`` on CPU tensors.

    Each test builds the scaler with ``init_scale=2.0`` on a
    ``DummyProcessGroup(0, 1)`` (presumably a rank-0 / world-size-1 stand-in so
    no real distributed backend is needed -- verify against ``common_fsdp``).
    """

    @unittest.skipIf(
        amp_definitely_not_available(), "no supported device (cuda, xla) found"
    )
    def test_grad_scaling(self):
        """``scale`` multiplies every tensor in a nested list/tuple by the scale."""
        pg = DummyProcessGroup(0, 1)
        scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
        t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
        t1 = torch.full((1,), 8.0, dtype=torch.float32, device="cpu")
        outputs = [t1.clone(), (t0.clone(), t1.clone()), [t0.clone(), t1.clone()]]
        outputs = scaler.scale(outputs)
        # One assertion per element so a failure reports the offending value.
        self.assertEqual(outputs[0], 16.0)
        self.assertEqual(outputs[1][0], 8.0)
        self.assertEqual(outputs[1][1], 16.0)
        self.assertEqual(outputs[2][0], 8.0)
        self.assertEqual(outputs[2][1], 16.0)
        # The scaler's internal scale tensor sits on the scaled tensors' device.
        self.assertEqual(scaler._scale.device, t1.device)

    @unittest.skipIf(
        amp_definitely_not_available(), "no supported device (cuda, xla) found"
    )
    def test_scaling_unscaling_sparse(self):
        """``_unscale_grads_`` unscales sparse grads and flags inf/overflow.

        Covers three cases: a finite fp32 sparse grad, an fp32 sparse grad
        containing inf, and an fp16 sparse grad whose duplicate indices
        overflow when coalesced.
        """
        pg = DummyProcessGroup(0, 1)
        scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
        inv_scale = torch.full((1,), 0.5, dtype=torch.float, device="cpu")
        found_inf = torch.full((1,), 0, dtype=torch.float, device="cpu")

        i = torch.tensor([[0, 1, 1], [2, 0, 2]], device="cpu", dtype=torch.int64)
        v = torch.tensor([16.0, 32.0, 64.0], dtype=torch.float, device="cpu")
        s = torch.sparse_coo_tensor(
            i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float
        )

        # unscale sparse tensors
        s1 = s.clone()
        s1.grad = s.clone()
        opt = torch.optim.SGD([s1], lr=1.0)
        found_inf.zero_()
        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf)[s1.device]
        self.assertEqual(found_inf, 0.0)
        self.assertEqual(s1.grad.to_dense(), (s / 2).to_dense())

        # unscale sparse tensor: inf
        v = torch.tensor([16.0, 32.0, float("inf")], dtype=torch.float, device="cpu")
        s1.grad = torch.sparse_coo_tensor(
            i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float
        )
        found_inf.zero_()
        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf)[s1.device]
        self.assertEqual(found_inf, 1.0)

        # unscale sparse tensor: overflow (marked as inf)
        i = torch.tensor([[1, 1, 1], [0, 0, 2]], device="cpu", dtype=torch.int64)
        # Index (1, 0) appears twice, so coalescing sums 2**15 + 2**15 = 2**16,
        # which is outside the float16 range and therefore becomes inf.
        v = torch.tensor([2**15, 2**15, 1.0], dtype=torch.float16, device="cpu")
        s2 = torch.sparse_coo_tensor(
            i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float16
        )
        s2.grad = s2.clone()
        # The optimizer must own the new fp16 parameter.  Reusing ``opt`` would
        # re-check ``s1`` (whose grad is still the inf tensor from the previous
        # case), so the assertion would pass without exercising the overflow.
        opt = torch.optim.SGD([s2], lr=1.0)
        found_inf.zero_()
        found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf)[s2.device]
        self.assertEqual(found_inf, 1.0)

    @unittest.skipIf(
        amp_definitely_not_available(), "no supported device (cuda, xla) found"
    )
    def test_inf_gradients_skip_optim_step(self):
        """An inf gradient makes ``step`` skip ``optimizer.step`` and return None."""
        pg = DummyProcessGroup(0, 1)
        scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
        loss = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
        t0 = torch.tensor([float("inf")], dtype=torch.float32, device="cpu")
        t0.grad = t0.clone()
        opt = torch.optim.SGD([t0], lr=1.0)
        scaler.scale(loss)
        ret_val = scaler.step(opt)
        self.assertIsNone(ret_val)
        # An SGD update would compute inf - lr * inf = nan; a skipped step
        # leaves the parameter at inf.
        self.assertTrue(torch.isinf(t0).all())


class TestShardedGradScalerParityWithDDP(FSDPTest):
    """Multi-GPU tests of FSDP + ``ShardedGradScaler`` against DDP + ``GradScaler``."""

    def _get_init_modes_for_test(self, cpu_offload):
        """Return the ``DEVICEInitMode`` values to test for ``cpu_offload``.

        ``DEVICE_AFTER`` and ``DEVICE_BEFORE`` are always returned;
        ``DEVICE_NEVER`` is added only when ``cpu_offload.offload_params`` is set.
        """
        modes = [DEVICEInitMode.DEVICE_AFTER, DEVICEInitMode.DEVICE_BEFORE]
        # Note that DEVICEInitMode.DEVICE_NEVER works currently only with CPU
        # offload as we explicitly bring the param back to CUDA device. In
        # general, it will not work since we try to all_gather p.data which is
        # on CPU but NCCL only supports GPU.
        if cpu_offload.offload_params:
            modes.append(DEVICEInitMode.DEVICE_NEVER)

        return modes

    @skip_if_lt_x_gpu(2)
    @parametrize(params, configs, subtest_name)
    def test_fsdp_ddp_parity_with_grad_scaler(
        self,
        cpu_offload: CPUOffload,
        sharding_strategy: Optional[ShardingStrategy],
        mixed_precision: Optional[str],
        use_orig_params: Optional[str],
    ):
        """Run ``_test_fsdp_parity`` with the sharded grad scaler enabled.

        Args:
            cpu_offload: FSDP CPU-offload config.
            sharding_strategy: FSDP sharding strategy, or ``None``.
            mixed_precision: ``"enable_mixed_precision"`` to use fp16
                param/reduce/buffer dtypes, ``None`` for no mixed precision.
            use_orig_params: ``"enable_use_orig_params"`` to test
                ``NonUniformReqGradNWM`` with ``use_orig_params=True``;
                ``None`` to test ``NestedWrappedModule`` without it.
        """
        init_modes = self._get_init_modes_for_test(cpu_offload)
        mp = (
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision is not None
            else None
        )
        # the ``NonUniformReqGradNWM`` model requires we set `init_scale`
        # more conservatively than default to avoid infs with the initial steps
        if use_orig_params == "enable_use_orig_params":
            use_orig = True
            model_cls = NonUniformReqGradNWM
            sharded_grad_scaler_kwargs = {"init_scale": 2.0**11}
        else:
            use_orig = False
            model_cls = NestedWrappedModule  # type: ignore[assignment]
            sharded_grad_scaler_kwargs = None
        for device_init_mode in init_modes:
            self._test_fsdp_parity(
                model_cls,
                FSDPInitMode.RECURSIVE,
                device_init_mode=device_init_mode,
                cpu_offload=cpu_offload,
                sharding_strategy=sharding_strategy,
                mixed_precision=mp,
                enable_sharded_grad_scaler=True,
                use_orig_params=use_orig,
                sharded_grad_scaler_kwargs=sharded_grad_scaler_kwargs,
            )

    def _build_model_and_optim(
        self,
        cpu_offload: Optional[CPUOffload] = None,
        use_orig_params: bool = False,
    ):
        """Build an FSDP model/optimizer plus a DDP reference model/optimizer.

        Both wrap copies of one deterministically initialized
        ``TransformerWithSharedParams`` and use ``Adam(lr=1e-2)``; the FSDP
        model auto-wraps each transformer encoder/decoder layer.

        Args:
            cpu_offload: FSDP CPU-offload config; ``None`` (the default) means
                ``CPUOffload(offload_params=False)``.
            use_orig_params: forwarded to the FSDP constructor.

        Returns:
            ``(model, optim, ref_model, ref_optim)``.
        """
        # A ``None`` sentinel avoids sharing one default instance across calls.
        if cpu_offload is None:
            cpu_offload = CPUOffload(offload_params=False)
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            DEVICEInitMode.DEVICE_BEFORE,
            deterministic=True,
        )
        # Copy before wrapping with FSDP so both models start from equal weights.
        ref_model = DDP(
            copy.deepcopy(model),
            device_ids=[self.rank],
        )
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
        fsdp_kwargs = {
            "use_orig_params": use_orig_params,
            "cpu_offload": cpu_offload,
            "auto_wrap_policy": ModuleWrapPolicy(
                {
                    TransformerEncoderLayer,
                    TransformerDecoderLayer,
                }
            ),
        }
        model = FSDP(model, **fsdp_kwargs)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        return model, optim, ref_model, ref_optim

    @skip_if_lt_x_gpu(2)
    def test_sharded_grad_scaler_found_inf(self):
        """Run the found-inf check for each ``use_orig_params``/offload combo."""
        self.run_subtests(
            {
                "use_orig_params": [False, True],
                "cpu_offload": [
                    CPUOffload(offload_params=True),
                    CPUOffload(offload_params=False),
                ],
            },
            self._test_sharded_grad_scaler_found_inf,
        )

    def _test_sharded_grad_scaler_found_inf(
        self,
        use_orig_params: bool,
        cpu_offload: CPUOffload,
    ):
        """Check that an inf gradient on one rank makes every rank skip the step.

        On even iterations an inf is written into one gradient: on every rank
        for the DDP reference, but only on rank 0 for FSDP, so the other FSDP
        ranks must learn about it through the scaler's collectives. On those
        iterations the scale must be multiplied by the backoff factor and the
        parameters must stay unchanged; on odd iterations the scale must stay
        the same and the parameters must change. The two models' scaled
        losses are compared on every iteration.
        """
        model, optim, ref_model, ref_optim = self._build_model_and_optim(
            cpu_offload=cpu_offload,
            use_orig_params=use_orig_params,
        )
        grad_scaler = ShardedGradScaler(init_scale=2.0)
        ref_grad_scaler = torch.amp.GradScaler(device="cuda", init_scale=2.0)
        device = torch.device("cuda")
        torch.manual_seed(42 + self.rank + 1)

        for step in range(10):
            # Reset every iteration so the parity check below compares this
            # iteration's (ref, fsdp) pair.  A list shared across iterations
            # would keep re-comparing the first iteration's losses.
            scaled_losses: List[torch.Tensor] = []
            should_find_inf = step % 2 == 0
            for _model, _optim, _grad_scaler in (
                (ref_model, ref_optim, ref_grad_scaler),
                (model, optim, grad_scaler),
            ):
                module = _model.module
                inp = module.get_input(device)
                _optim.zero_grad()
                output = _model(*inp)
                loss = module.get_loss(inp, output)
                scaled_loss = _grad_scaler.scale(loss)
                scaled_losses.append(scaled_loss)
                scaled_loss.backward()
                orig_params = [
                    param.detach().clone()
                    for param in _model.parameters()
                    if param.grad is not None
                ]
                if should_find_inf and (
                    _model is ref_model or (_model is model and self.rank == 0)
                ):
                    # other ranks should find infs from rank 0
                    # after collectives
                    for param in _model.parameters():
                        if param.grad is None:
                            continue
                        param.grad.fill_(float("inf"))
                        break
                _grad_scaler.step(_optim)
                orig_scale = _grad_scaler.get_scale()
                _grad_scaler.update()
                if should_find_inf:
                    self.assertEqual(
                        _grad_scaler.get_scale(),
                        orig_scale * _grad_scaler.get_backoff_factor(),
                        (
                            f"rank: {self.rank} iter: {step} expect origin scale {orig_scale} "
                            f"to be backed off by {_grad_scaler.get_backoff_factor()} "
                            f"but got {_grad_scaler.get_scale()}"
                        ),
                    )
                else:
                    self.assertEqual(
                        _grad_scaler.get_scale(),
                        orig_scale,
                        (
                            f"rank: {self.rank} iter: {step} expect same scale {orig_scale} "
                            f"but got {_grad_scaler.get_scale()}"
                        ),
                    )
                for param, orig_param in zip(
                    [param for param in _model.parameters() if param.grad is not None],
                    orig_params,
                ):
                    if should_find_inf:
                        self.assertEqual(
                            param,
                            orig_param,
                            (
                                f"rank: {self.rank} iter: {step} expect the same params before "
                                f"and after optim.step but got {param} vs {orig_param}"
                            ),
                        )
                    else:
                        self.assertNotEqual(
                            param,
                            orig_param,
                            (
                                f"rank: {self.rank} iter: {step} expect the updated params after "
                                f"optim.step but got {param} vs {orig_param}"
                            ),
                        )
            self.assertEqual(
                scaled_losses[0],
                scaled_losses[1],
                f"iter: {step} {scaled_losses[0]} vs {scaled_losses[1]}",
            )


# Turn the ``@parametrize``-decorated methods of both test classes into
# concrete test methods.  The loop variable is removed afterwards so the
# module namespace ends up unchanged.
for _test_cls in (TestShardGradScaler, TestShardedGradScalerParityWithDDP):
    instantiate_parametrized_tests(_test_cls)
del _test_cls

if __name__ == "__main__":
    run_tests()