File: test_fsdp_fine_tune.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (413 lines) | stat: -rw-r--r-- 15,798 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
# Owner(s): ["oncall: distributed"]

import copy
import sys
from unittest import mock

import torch
import torch.distributed as dist
import torch.nn as nn
from torch._utils import _get_device_module
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest, get_devtype
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_CUDA,
    TEST_WITH_DEV_DBG_ASAN,
)


# Device used by the tests below for model placement, FSDP `device_id`, and
# input tensors; built from whatever `get_devtype()` returns.
device_type = torch.device(get_devtype())

# Without the distributed package there is nothing to test: report the skip on
# stderr and exit with status 0.
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

# Same early exit (status 0) when running under the dev-asan configuration.
if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class LinearUnusedInput(nn.Linear):
    """Linear layer whose forward accepts two inputs but only uses the first."""

    def forward(self, frozen_input, learnable_input):
        """
        Applies the linear transform to ``frozen_input``.

        ``learnable_input`` is accepted but never read.
        """
        return nn.Linear.forward(self, frozen_input)


class ModelUnusedInput(nn.Module):
    """
    Three ``LinearUnusedInput`` layers applied side by side to the same pair
    of inputs; the middle layer's parameters can be frozen at construction.
    """

    def __init__(self, freeze: bool):
        """
        Args:
            freeze: if true, ``layer1_frozen``'s parameters are marked as not
                requiring gradients.
        """
        super().__init__()
        # Layers are created in a fixed order so that seeding the RNG before
        # construction yields the same initial weights regardless of `freeze`.
        self.layer0 = LinearUnusedInput(4, 4)
        self.layer1_frozen = LinearUnusedInput(4, 4)
        if freeze:
            self.layer1_frozen.requires_grad_(False)
        self.layer2 = LinearUnusedInput(4, 4)

    def forward(self, frozen_input, learnable_input):
        """
        Runs every layer on the two inputs and concatenates the three layer
        outputs followed by ``learnable_input`` itself.
        """
        branch_outputs = [
            layer(frozen_input, learnable_input)
            for layer in (self.layer0, self.layer1_frozen, self.layer2)
        ]
        return torch.concat([*branch_outputs, learnable_input])


class TestFSDPFineTune(FSDPTest):
    """Tests fine-tuning cases where some parameters are frozen."""

    NUM_LINEARS = 6

    @property
    def world_size(self) -> int:
        return min(_get_device_module(self.device_type).device_count(), 2)

    def _init_seq_module(self, device) -> nn.Module:
        torch.manual_seed(42)
        modules = []
        for _ in range(self.NUM_LINEARS):
            modules += [nn.Linear(5, 5, device=device), nn.ReLU()]
        seq = nn.Sequential(*modules)
        self._set_seq_module_requires_grad(seq, False)
        return seq

    def _set_seq_module_requires_grad(self, seq: nn.Module, requires_grad: bool):
        # Assume that the linears are leaf modules, meaning that we can pass
        # `recurse=True` to have this to work for both pre/post FSDP wrapping
        for i in range(self.NUM_LINEARS):
            # Only set for every other linear to test mixing frozen/non-frozen
            if i % 2 == 0:
                for param in seq[i * 2].parameters(recurse=True):
                    param.requires_grad = requires_grad

    @skip_if_lt_x_gpu(2)
    def test_backward_reshard_hooks(self, device):
        """
        Tests that flat parameters that do not require gradients still get
        resharded after backward, across sharding strategies,
        ``use_orig_params``, input ``requires_grad``, and late unfreezing.
        """
        sharding_strategies = [
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.NO_SHARD,
        ]
        subtest_config = {
            "device_id": [device],
            "sharding_strategy": sharding_strategies,
            "use_orig_params": [False, True],
            "inp_requires_grad": [False, True],
            "unfreeze_params": [False, True],
        }
        self.run_subtests(subtest_config, self._test_backward_reshard_hooks)

    def _test_backward_reshard_hooks(
        self,
        device_id,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
        inp_requires_grad: bool,
        unfreeze_params: bool,
    ):
        """
        Runs three steps on an FSDP-wrapped sequential model in which every
        other linear starts frozen, counting the calls to
        ``_post_backward_reshard`` made during each backward pass.

        The middle step is a forward under ``torch.no_grad()`` with no
        backward. When ``unfreeze_params`` is set, the frozen linears are
        unfrozen right before the last step.

        ``device_id`` is supplied by the subtest config but is not read here;
        the module-level ``device_type`` is what is passed to FSDP and used
        for the inputs.
        """
        model = FSDP(
            self._init_seq_module(device_type),
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
            device_id=device_type,
        )
        # Grab the real function before it is patched so the counting wrapper
        # can delegate to it
        real_post_backward_reshard = (
            torch.distributed.fsdp._runtime_utils._post_backward_reshard
        )
        reshard_calls = 0

        def _counting_post_backward_reshard(*args, **kwargs):
            nonlocal reshard_calls
            reshard_calls += 1
            return real_post_backward_reshard(*args, **kwargs)

        num_steps = 3
        last_step_idx = num_steps - 1
        # Interleave a `no_grad` step to validate that post-backward hooks are
        # not registered in that context and that `requires_grad` is reset
        # appropriately when unfreezing
        nograd_step_idx = 1

        with mock.patch(
            "torch.distributed.fsdp._runtime_utils._post_backward_reshard",
            _counting_post_backward_reshard,
        ):
            for step_idx in range(num_steps):
                # Unfreezing on the last step emulates some kinds of
                # fine-tuning
                just_unfrozen = unfreeze_params and step_idx == last_step_idx
                if just_unfrozen:
                    self._set_seq_module_requires_grad(model, True)

                inp = torch.randn(
                    (8, 5), device=device_type, requires_grad=inp_requires_grad
                )
                if step_idx == nograd_step_idx:
                    # Forward only: nothing to assert and the counter is left
                    # as is
                    with torch.no_grad():
                        model(inp)
                    continue

                model(inp).sum().backward()

                if just_unfrozen:
                    self.assertTrue(
                        all(p.requires_grad for p in model.parameters()),
                        msg="Expected all parameters to require grad but some did not!",
                    )

                if just_unfrozen or inp_requires_grad:
                    # Every linear goes through a post-backward reshard
                    expected_reshard_calls = self.NUM_LINEARS
                else:
                    # If the input does not require gradient, then the 0th
                    # frozen linear gets resharded in the catch-all reshard
                    # since we cannot register an autograd hook on it
                    expected_reshard_calls = self.NUM_LINEARS - 1
                self.assertEqual(reshard_calls, expected_reshard_calls)
                reshard_calls = 0

    def _init_multi_traversal_module(self, device) -> nn.Module:
        torch.manual_seed(42)

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer_0 = nn.Linear(5, 5, device=device)
                self.layer_no_grad = nn.Linear(5, 5, device=device)
                self.layer_with_grad = nn.Linear(5, 5, device=device)
                self.layer_no_grad.requires_grad_(False)

            def forward(self, x):
                # Layer `layer_no_grad` and `layer_with_grad` are called
                # multiple times, IOW, their parameters are used multiple times
                # during forward pass.
                x = self.layer_0(x)
                for _ in range(10):
                    x = self.layer_no_grad(self.layer_with_grad(x))
                    # Make sure calling the same layer multiple times works
                    # regardless whether gradient is enabled.
                    with torch.no_grad():
                        x += self.layer_with_grad(x)
                return x

        return TestModule()

    @skip_if_lt_x_gpu(2)
    def test_hooks_multi_traversal(self):
        """
        Tests that the hooks reshard / unshard correctly when the same
        parameters are used multiple times during a forward pass.
        """
        sharding_strategies = [
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.NO_SHARD,
        ]
        subtest_config = {
            "sharding_strategy": sharding_strategies,
            "use_orig_params": [False, True],
            "inp_requires_grad": [False, True],
            "forward_prefetch": [False, True],
        }
        self.run_subtests(subtest_config, self._test_hooks_multi_traversal)

    def _test_hooks_multi_traversal(
        self,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
        inp_requires_grad: bool,
        forward_prefetch: bool,
    ):
        """
        Trains an FSDP copy and a DDP copy of the multi-traversal module for
        six steps on identical inputs and checks that their losses match at
        every step.
        """
        base_module = self._init_multi_traversal_module(device_type.type)
        fsdp_model = FSDP(
            copy.deepcopy(base_module),
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
            forward_prefetch=forward_prefetch,
            device_id=device_type,
        )
        ddp_model = DDP(copy.deepcopy(base_module), device_ids=[device_type])
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-2)

        def _train_step(model, optim, inp):
            # One forward/backward/optimizer step; returns the summed loss
            loss = model(inp).sum()
            loss.backward()
            optim.step()
            optim.zero_grad()
            return loss

        # Seed per rank so each rank draws different inputs
        torch.manual_seed(self.rank + 1)
        for _ in range(6):
            inp = torch.randn(
                (8, 5), device=device_type, requires_grad=inp_requires_grad
            )
            # FSDP steps first, then DDP, on the very same input tensor
            fsdp_loss = _train_step(fsdp_model, fsdp_optim, inp)
            ddp_loss = _train_step(ddp_model, ddp_optim, inp)
            torch.testing.assert_close(fsdp_loss, ddp_loss)

    @skip_if_lt_x_gpu(2)
    def test_parity_with_ddp(self):
        """
        Tests that FSDP matches DDP when some flat parameters require
        gradients and others do not.
        """
        sharding_strategies = [
            ShardingStrategy.FULL_SHARD,
            ShardingStrategy.SHARD_GRAD_OP,
            ShardingStrategy.NO_SHARD,
        ]
        subtest_config = {
            "sharding_strategy": sharding_strategies,
            "use_orig_params": [False, True],
        }
        self.run_subtests(subtest_config, self._test_parity_with_ddp)

    def _test_parity_with_ddp(
        self,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
    ):
        """
        Trains an FSDP copy and a DDP copy of the partially frozen sequential
        model for six steps on identical inputs and checks that their losses
        match at every step.
        """
        base_module = self._init_seq_module(device_type)
        fsdp_model = FSDP(
            copy.deepcopy(base_module),
            auto_wrap_policy=ModuleWrapPolicy({nn.Linear}),
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
            device_id=device_type,
        )
        ddp_model = DDP(copy.deepcopy(base_module), device_ids=[device_type])
        fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2)
        ddp_optim = torch.optim.Adam(ddp_model.parameters(), lr=1e-2)

        # Default tolerances when TEST_CUDA is set; explicit looser ones
        # otherwise
        tolerances = {} if TEST_CUDA else {"atol": 1e-03, "rtol": 1e-03}

        def _train_step(model, optim, inp):
            # One forward/backward/optimizer step; returns the summed loss
            loss = model(inp).sum()
            loss.backward()
            optim.step()
            optim.zero_grad()
            return loss

        # Seed per rank so each rank draws different inputs
        torch.manual_seed(self.rank + 1)
        for _ in range(6):
            inp = torch.randn((8, 5), device=device_type.type)
            # FSDP steps first, then DDP, on the very same input tensor
            fsdp_loss = _train_step(fsdp_model, fsdp_optim, inp)
            ddp_loss = _train_step(ddp_model, ddp_optim, inp)
            torch.testing.assert_close(fsdp_loss, ddp_loss, **tolerances)

    @skip_if_lt_x_gpu(2)
    def test_parity_with_non_frozen_fsdp(self, device):
        """
        For frozen modules with unused input, reshard could happen without
        unshard. Verifies numerical parity between the
        `_post_backward_reshard_only_hook` path and the `_post_backward_hook`
        path.
        """
        full_precision = MixedPrecision()
        half_precision = MixedPrecision(
            param_dtype=torch.float16,
            buffer_dtype=torch.float16,
            reduce_dtype=torch.float16,
        )
        subtest_config = {
            "device_id": [device],
            "sharding_strategy": [
                ShardingStrategy.FULL_SHARD,
                ShardingStrategy.SHARD_GRAD_OP,
            ],
            "use_orig_params": [True, False],
            "offload_params": [True, False],
            "mixed_precision": [full_precision, half_precision],
            "backward_prefetch": [
                BackwardPrefetch.BACKWARD_PRE,
                BackwardPrefetch.BACKWARD_POST,
            ],
        }
        self.run_subtests(subtest_config, self._test_parity_with_non_frozen_fsdp)

    def _test_parity_with_non_frozen_fsdp(
        self,
        device_id,
        sharding_strategy: ShardingStrategy,
        use_orig_params: bool,
        offload_params: bool,
        mixed_precision: MixedPrecision,
        backward_prefetch: BackwardPrefetch,
    ):
        """
        Compares an FSDP model whose ``layer1_frozen`` is frozen against an
        identically initialized FSDP reference model where that layer is not
        frozen but is left out of the optimizer. Both are trained for six
        steps; losses must match at every step and the full parameters must
        match at the end.

        ``device_id`` is supplied by the subtest config but is not read here;
        the module-level ``device_type`` is used instead.
        """
        # Seed before each construction so both models start from the same
        # weights
        torch.manual_seed(42)
        model = ModelUnusedInput(freeze=True).to(device_type)
        torch.manual_seed(42)
        ref_model = ModelUnusedInput(freeze=False).to(device_type)

        fsdp_kwargs = dict(
            device_id=device_type,
            auto_wrap_policy=ModuleWrapPolicy({LinearUnusedInput}),
            sharding_strategy=sharding_strategy,
            use_orig_params=use_orig_params,
            cpu_offload=CPUOffload(offload_params=offload_params),
            mixed_precision=mixed_precision,
            backward_prefetch=backward_prefetch,
        )
        model = FSDP(model, **fsdp_kwargs)
        ref_model = FSDP(ref_model, **fsdp_kwargs)

        model_optim = torch.optim.Adam(model.parameters(), lr=1e-2)
        # The reference optimizer skips `layer1_frozen`, so that layer is
        # never updated in either model
        frozen_prefix = "_fsdp_wrapped_module.layer1_frozen"
        ref_optim_params = []
        for name, param in ref_model.named_parameters():
            if name.startswith(frozen_prefix):
                continue
            ref_optim_params.append(param)
        ref_model_optim = torch.optim.Adam(ref_optim_params, lr=1e-2)

        # Seed per rank so each rank draws different inputs
        torch.manual_seed(self.rank + 1)
        model_and_optim_pairs = (
            (model, model_optim),
            (ref_model, ref_model_optim),
        )
        for _ in range(6):
            frozen_input = torch.randn((4, 4), device=device_type, requires_grad=False)
            # NOTE(review): `learnable_input` is created but never passed to
            # either model -- both forward arguments below are `frozen_input`.
            # It still draws from the RNG stream, so it is kept as is. Confirm
            # whether passing `learnable_input` as the second argument was
            # intended.
            learnable_input = torch.randn(
                (4, 4), device=device_type, requires_grad=True
            )
            step_losses = []
            for net, optim in model_and_optim_pairs:
                loss = net(frozen_input, frozen_input).sum()
                step_losses.append(loss)
                loss.backward()
                optim.step()
                optim.zero_grad()
            frozen_loss, ref_loss = step_losses
            self.assertEqual(frozen_loss, ref_loss)

        with FSDP.summon_full_params(model), FSDP.summon_full_params(ref_model):
            for param, ref_param in zip(model.parameters(), ref_model.parameters()):
                self.assertEqual(param, ref_param)


# Device types passed as `only_for` when instantiating the device-specific
# variants of `TestFSDPFineTune` into this module's globals.
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestFSDPFineTune, globals(), only_for=devices)
# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    run_tests()