File: test_fsdp_pure_fp16.py (package: pytorch-cuda 2.6.0+dfsg-7)

# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch import distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    DEVICEInitMode,
    FSDPInitMode,
    FSDPTest,
    get_devtype,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


device_type = torch.device(get_devtype())

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestPureFP16(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_pure_fp16_training(self):
        """Tests pure FP16 training, including when the parameter's dtype is
        changed after FSDP initialization and before training."""
        self.run_subtests(
            {
                "cpu_offload": [
                    CPUOffload(offload_params=True),
                    CPUOffload(offload_params=False),
                ]
            },
            self._test_pure_fp16_training,
        )

    def _test_pure_fp16_training(self, cpu_offload: CPUOffload):
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            device_init_mode=DEVICEInitMode.DEVICE_BEFORE,
            # Run one iteration to avoid NaN without a gradient scaler
            num_iters=1,
            cpu_offload=cpu_offload,
            use_pure_fp16=True,
        )

    @skip_if_lt_x_gpu(2)
    def test_fp16_dtypes(self):
        """
        Tests that both user-facing parameter/gradient dtypes and internal
        saved dtype attributes are as expected when using an FP16 model,
        possibly with explicit mixed precision enabled.
        """
        self.run_subtests(
            {
                "to_half_before_fsdp_init": [False, True],
                "use_orig_params": [False, True],
                "mixed_precision": [
                    MixedPrecision(),
                    MixedPrecision(
                        param_dtype=torch.float16,
                        reduce_dtype=torch.float32,
                    ),
                    MixedPrecision(
                        param_dtype=torch.float32,
                    ),
                ],
            },
            self._test_fp16_dtypes,
        )

    def _test_fp16_dtypes(
        self,
        to_half_before_fsdp_init: bool,
        use_orig_params: bool,
        mixed_precision: MixedPrecision,
    ):
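        # NO_FSDP / DEVICE_NEVER: build the plain module without wrapping or
        # moving it to the device; the FSDP wrapper below places it on the
        # target device via `device_id`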
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            DEVICEInitMode.DEVICE_NEVER,
            {
                "device_id": device_type,
            },
        )
        fsdp_kwargs = {
            "use_orig_params": use_orig_params,
            "device_id": device_type,
            "mixed_precision": mixed_precision,
        }
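        # Cast to FP16 either before or after wrapping with FSDP so that both
        # orderings of the dtype conversion are exercised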
        if to_half_before_fsdp_init:
            model = model.half()
        fsdp_model = FSDP(model, **fsdp_kwargs)
        if not to_half_before_fsdp_init:
            fsdp_model = fsdp_model.half()
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
        inp = tuple(
            t.half() if torch.is_tensor(t) else t
            for t in fsdp_model.module.get_input(self.device_type)
        )
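        # Run one forward/backward pass so that gradients exist for the dtype
        # checks below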
        out = fsdp_model(*inp)
        out.sum().backward()

        # Check handle dtype attributes
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.dtype, torch.float16)
            self.assertEqual(handle.flat_param.grad.dtype, torch.float16)
            self.assertEqual(handle._orig_param_dtype, torch.float16)
            # Specifying `mixed_precision` takes precedence over the model
            # dtype for both `param_dtype` and `reduce_dtype`
            if mixed_precision.param_dtype is not None:
                self.assertEqual(
                    handle._fwd_bwd_param_dtype, mixed_precision.param_dtype
                )
            else:
                self.assertEqual(handle._fwd_bwd_param_dtype, torch.float16)
            if mixed_precision.reduce_dtype is not None:
                self.assertEqual(handle._reduce_dtype, mixed_precision.reduce_dtype)
            elif (
                mixed_precision.reduce_dtype is None
                and mixed_precision.param_dtype is not None
            ):
                # Special case: infer reduce dtype from parameter dtype
                self.assertEqual(handle._reduce_dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(handle._reduce_dtype, torch.float16)

        # Check parameter/gradient dtypes
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float16)


devices = ("cuda", "hpu")
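# Instantiate the test class once per supported accelerator backend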
instantiate_device_type_tests(TestPureFP16, globals(), only_for=devices)
if __name__ == "__main__":
    run_tests()