# Owner(s): ["oncall: distributed"]

import sys

import torch
import torch.distributed.fsdp._traversal_utils as traversal_utils
from torch import distributed as dist
from torch.distributed.fsdp import (
    CPUOffload,
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    DEVICEInitMode,
    FSDPInitMode,
    FSDPTest,
    get_devtype,
    NestedWrappedModule,
)
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

device_type = torch.device(get_devtype())

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestPureFP16(FSDPTest):
    @skip_if_lt_x_gpu(2)
    def test_pure_fp16_training(self):
        """Tests pure FP16 training, including when the parameters' dtype is
        changed after FSDP initialization and before training."""
        self.run_subtests(
            {
                "cpu_offload": [
                    CPUOffload(offload_params=True),
                    CPUOffload(offload_params=False),
                ]
            },
            self._test_pure_fp16_training,
        )

    def _test_pure_fp16_training(self, cpu_offload: CPUOffload):
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            device_init_mode=DEVICEInitMode.DEVICE_BEFORE,
            # Run one iteration to avoid NaN without a gradient scaler
            num_iters=1,
            cpu_offload=cpu_offload,
            use_pure_fp16=True,
        )

    @skip_if_lt_x_gpu(2)
    def test_fp16_dtypes(self):
        """
        Tests that both the user-facing parameter/gradient dtypes and the
        internal saved dtype attributes are as expected when using an FP16
        model, possibly with explicit mixed precision enabled.
        """
        self.run_subtests(
            {
                "to_half_before_fsdp_init": [False, True],
                "use_orig_params": [False, True],
                "mixed_precision": [
                    MixedPrecision(),
                    MixedPrecision(
                        param_dtype=torch.float16,
                        reduce_dtype=torch.float32,
                    ),
                    MixedPrecision(
                        param_dtype=torch.float32,
                    ),
                ],
            },
            self._test_fp16_dtypes,
        )

    def _test_fp16_dtypes(
        self,
        to_half_before_fsdp_init: bool,
        use_orig_params: bool,
        mixed_precision: MixedPrecision,
    ):
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            DEVICEInitMode.DEVICE_NEVER,
            {
                "device_id": device_type,
            },
        )
        fsdp_kwargs = {
            "use_orig_params": use_orig_params,
            "device_id": device_type,
            "mixed_precision": mixed_precision,
        }
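        # Cast the model to FP16 either before or after wrapping with FSDP so
        # that both orderings are exercised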
        if to_half_before_fsdp_init:
            model = model.half()
        fsdp_model = FSDP(model, **fsdp_kwargs)
        if not to_half_before_fsdp_init:
            fsdp_model = fsdp_model.half()
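        # Every parameter should now be FP16, regardless of when the cast ran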
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
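        # Cast any tensor inputs to FP16 to match the parameter dtype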
        inp = tuple(
            t.half() if torch.is_tensor(t) else t
            for t in fsdp_model.module.get_input(self.device_type)
        )
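        # A single forward/backward pass materializes the gradients whose
        # dtypes are checked below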
        out = fsdp_model(*inp)
        out.sum().backward()
        # Check handle dtype attributes
        for handle in traversal_utils._get_fsdp_handles(fsdp_model):
            self.assertEqual(handle.flat_param.dtype, torch.float16)
            self.assertEqual(handle.flat_param.grad.dtype, torch.float16)
            self.assertEqual(handle._orig_param_dtype, torch.float16)
            # Specifying `mixed_precision` takes precedence over the model
            # dtype for both `param_dtype` and `reduce_dtype`
            if mixed_precision.param_dtype is not None:
                self.assertEqual(
                    handle._fwd_bwd_param_dtype, mixed_precision.param_dtype
                )
            else:
                self.assertEqual(handle._fwd_bwd_param_dtype, torch.float16)
            if mixed_precision.reduce_dtype is not None:
                self.assertEqual(handle._reduce_dtype, mixed_precision.reduce_dtype)
            elif (
                mixed_precision.reduce_dtype is None
                and mixed_precision.param_dtype is not None
            ):
                # Special case: infer reduce dtype from parameter dtype
                self.assertEqual(handle._reduce_dtype, mixed_precision.param_dtype)
            else:
                self.assertEqual(handle._reduce_dtype, torch.float16)
        # Check parameter/gradient dtypes
        for param in fsdp_model.parameters():
            self.assertEqual(param.dtype, torch.float16)
            if param.grad is not None:
                self.assertEqual(param.grad.dtype, torch.float16)


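# Instantiate device-specific variants of `TestPureFP16` for each supported backend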
devices = ("cuda", "hpu")
instantiate_device_type_tests(TestPureFP16, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()