1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467
|
# Owner(s): ["oncall: distributed"]
import contextlib
import copy
import functools
import math
import threading
import unittest
from typing import Any, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils._pytree as pytree
from torch.autograd.grad_mode import _unsafe_preserve_version_counter
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
check_sharded_parity,
FSDPTest,
FSDPTestMultiThread,
MLP,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.two_tensor import TwoTensor
def two_tensor_fsdp_pre_all_gather_v1(
self, mesh: DeviceMesh
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
all_gather_inputs = (self.a, self.b)
metadata = None
return all_gather_inputs, metadata
def two_tensor_fsdp_pre_all_gather_v2(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
outer_stride: Tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
all_gather_inputs = (self.a, self.b)
metadata = None
return all_gather_inputs, metadata
def two_tensor_fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
assert metadata is None, f"{metadata}"
a, b = all_gather_outputs
if out is not None:
assert isinstance(out, TwoTensor), f"{type(out)}"
if a.dtype == param_dtype:
assert a.untyped_storage().data_ptr() == out.a.untyped_storage().data_ptr()
assert b.untyped_storage().data_ptr() == out.b.untyped_storage().data_ptr()
else:
assert out.a.dtype == param_dtype, f"{out.a.dtype} {param_dtype}"
assert out.b.dtype == param_dtype, f"{out.b.dtype} {param_dtype}"
out.a.copy_(a)
out.b.copy_(b)
return
tensors_to_free = (a, b)
# If the cast is real, then the all-gather outputs will not alias the
# returned `TwoTensor`'s `a` and `b`
two_tensor = TwoTensor(a, b).to(param_dtype)
return two_tensor, tensors_to_free
class BFloat16AllGatherTensor(torch.Tensor):
@staticmethod
def __new__(cls, data: torch.Tensor, pad_in_pre_all_gather: bool = True):
return torch.Tensor._make_wrapper_subclass(
cls,
data.shape,
data.stride(),
data.storage_offset(),
dtype=data.dtype,
device=data.device,
)
def __init__(self, data: torch.Tensor, pad_in_pre_all_gather: bool = True):
self._data = data
self._pad_in_pre_all_gather = pad_in_pre_all_gather
def fsdp_pre_all_gather(
self,
mesh: DeviceMesh,
outer_size: torch.Size,
outer_stride: Tuple[int, ...],
module: nn.Module,
mp_policy: MixedPrecisionPolicy,
) -> Tuple[Tuple[torch.Tensor, ...], Any]:
assert mesh.ndim == 1, f"{mesh.ndim}"
mesh_size = mesh.size()
requires_padding = outer_size[0] % mesh_size != 0
if requires_padding and self._pad_in_pre_all_gather:
sharded_padded_size = list(outer_size)
sharded_padded_size[0] = math.ceil(outer_size[0] / mesh_size)
padded_out = torch.empty(
sharded_padded_size, dtype=torch.bfloat16, device=self.device
)
padded_out[: self._data.size(0)].copy_(self._data)
return (padded_out,), None
else:
return self._data.to(torch.bfloat16), None
def fsdp_post_all_gather(
self,
all_gather_outputs: Tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: Optional[torch.Tensor] = None,
) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
assert metadata is None, f"{metadata}"
(tensor,) = all_gather_outputs
assert tensor.dtype == torch.bfloat16, f"{tensor.dtype}"
if out is not None:
with _unsafe_preserve_version_counter(out):
out.copy_(tensor)
return
upcast_tensor = tensor.to(param_dtype)
return upcast_tensor, (tensor, upcast_tensor)
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
pad_in_pre_all_gather = None
def unwrap(x: cls):
nonlocal pad_in_pre_all_gather
if pad_in_pre_all_gather is None:
pad_in_pre_all_gather = x._pad_in_pre_all_gather
else:
assert pad_in_pre_all_gather == x._pad_in_pre_all_gather
return x._data
out = func(
*pytree.tree_map_only(cls, unwrap, args),
**pytree.tree_map_only(cls, unwrap, kwargs),
)
return pytree.tree_map_only(
torch.Tensor, lambda x: cls(x, pad_in_pre_all_gather), out
)
def __tensor_flatten__(self):
return ["_data"], None
@staticmethod
def __tensor_unflatten__(
inner_tensors, outer_size: torch.Size, outer_stride: Tuple[int, ...]
):
return inner_tensors["_data"]
def __repr__(self):
return f"{self.__class__.__name__}({self._data})"
class TestFullyShardAllGatherExtensionsCommon:
    """Mixin with helpers shared by the multi-process and multi-thread tests."""

    # Shared by every caller of ``_patch_two_tensor_fsdp_all_gather``. In the
    # multi-thread test all ranks are threads in one process, so a lock created
    # per call would give each thread its own lock and exclude nothing.
    _two_tensor_patch_lock = threading.Lock()

    @property
    def world_size(self) -> int:
        return 2

    @contextlib.contextmanager
    def _patch_two_tensor_fsdp_all_gather(self, pre_all_gather_version: int):
        """Temporarily install the FSDP all-gather extension hooks on ``TwoTensor``.

        Args:
            pre_all_gather_version: ``1`` to install the v1 pre-all-gather
                signature, ``2`` for the v2 signature.

        Raises:
            ValueError: If ``pre_all_gather_version`` is neither 1 nor 2.
        """
        if pre_all_gather_version == 1:
            pre_all_gather = two_tensor_fsdp_pre_all_gather_v1
        elif pre_all_gather_version == 2:
            pre_all_gather = two_tensor_fsdp_pre_all_gather_v2
        else:
            raise ValueError(
                f"Unsupported pre_all_gather_version: {pre_all_gather_version}"
            )
        TwoTensor.fsdp_pre_all_gather = pre_all_gather
        TwoTensor.fsdp_post_all_gather = two_tensor_fsdp_post_all_gather
        # Make sure every rank has patched before any rank starts using it.
        dist.barrier()
        try:
            yield
        finally:
            # Make sure every rank is done before any rank unpatches.
            dist.barrier()
            # Only one thread needs to delete; the shared lock keeps the
            # hasattr/delattr pairs from interleaving across threads.
            with self._two_tensor_patch_lock:
                if hasattr(TwoTensor, "fsdp_pre_all_gather"):
                    delattr(TwoTensor, "fsdp_pre_all_gather")
                if hasattr(TwoTensor, "fsdp_post_all_gather"):
                    delattr(TwoTensor, "fsdp_post_all_gather")

    def _init_two_tensor_mlp(self) -> nn.Module:
        """Build three bias-free ``MLP(8)`` blocks whose weights are ``TwoTensor``s."""
        # Disable bias because the reference model will end up with a bias
        # gradient that is a `TwoTensor`, whereas the FSDP model does not
        model = nn.Sequential(*[MLP(8, bias=False) for _ in range(3)])
        for mlp in model:
            mlp.in_proj.weight = nn.Parameter(
                TwoTensor(mlp.in_proj.weight, mlp.in_proj.weight.clone())
            )
            mlp.out_proj.weight = nn.Parameter(
                TwoTensor(mlp.out_proj.weight, mlp.out_proj.weight.clone())
            )
        return model
class TestFullyShardAllGatherExtensionsMultiProcess(
    TestFullyShardAllGatherExtensionsCommon, FSDPTest
):
    """Multi-process FSDP all-gather extension tests (requires 2+ GPUs)."""

    @skip_if_lt_x_gpu(2)
    def test_all_gather_extensions_train_parity(self):
        # Exercise both supported `fsdp_pre_all_gather` signatures.
        for pre_all_gather_version in (1, 2):
            with self._patch_two_tensor_fsdp_all_gather(
                pre_all_gather_version=pre_all_gather_version
            ):
                self.run_subtests(
                    {"reshard_after_forward": [True, False]},
                    self._test_all_gather_extensions_train_parity,
                )

    def _test_all_gather_extensions_train_parity(self, reshard_after_forward: bool):
        """Train an FSDP-sharded ``TwoTensor`` MLP next to an unsharded
        reference copy and check losses and parameters stay in parity.
        """
        torch.manual_seed(42)
        model = self._init_two_tensor_mlp()
        ref_model = copy.deepcopy(model).cuda()
        ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=True)
        fully_shard_fn = functools.partial(
            fully_shard, reshard_after_forward=reshard_after_forward
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
        check_sharded_parity(self, ref_model, model)

        # Inputs differ per rank (rank-dependent seed), so the reference
        # gradients are summed across ranks and divided by the world size.
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for iter_idx in range(10):
            losses: List[torch.Tensor] = []
            for _model in (ref_model, model):
                losses.append(_model(inp).sum())
                losses[-1].backward()
                if _model is ref_model:
                    for param in _model.parameters():
                        dist.all_reduce(param.grad)
                        param.grad.detach().div_(self.world_size)
            self.assertEqual(losses[0], losses[1])
            check_sharded_parity(self, ref_model, model)
            for _optim in (ref_optim, optim):
                _optim.step()
                # Alternate between setting grads to None and zeroing them.
                _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            check_sharded_parity(self, ref_model, model)
class TestFullyShardAllGatherExtensionsMultiThread(
    TestFullyShardAllGatherExtensionsCommon, FSDPTestMultiThread
):
    """FSDP all-gather extension tests run with one thread per rank on cuda:0."""

    @property
    def world_size(self) -> int:
        return 8

    @property
    def device(self) -> torch.device:
        return torch.device("cuda:0")

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_all_gather_extensions_end_to_end(self):
        # Exercise both supported `fsdp_pre_all_gather` signatures.
        for pre_all_gather_version in (1, 2):
            with self._patch_two_tensor_fsdp_all_gather(
                pre_all_gather_version=pre_all_gather_version
            ):
                self.run_subtests(
                    {"reshard_after_forward": [True, False]},
                    self._test_all_gather_extensions_end_to_end,
                )

    def _test_all_gather_extensions_end_to_end(self, reshard_after_forward: bool):
        """Meta-init, shard with bf16 mixed precision, materialize, and train."""
        # Check that we can run the meta-device initialization flow
        with torch.device("meta"):
            model = self._init_two_tensor_mlp()
        for param in model.parameters():
            self.assertEqual(param.device, torch.device("meta"))
        fully_shard_fn = functools.partial(
            fully_shard,
            reshard_after_forward=reshard_after_forward,
            mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16),
        )
        for mlp in model:
            fully_shard_fn(mlp)
        fully_shard_fn(model)
        model.to_empty(device=self.device)
        for param in model.parameters():
            nn.init.trunc_normal_(param)
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)

        # Run a few iterations to check for errors
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for _ in range(3):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_all_gather_extensions_monkey_patch(self):
        # Thread-local flag recording whether the patched pre-hook ran in
        # this rank's thread.
        tls = threading.local()
        tls.ran_pre_all_gather = False

        # Define a pre/post-all-gather pair that quantizes to bf16 for the
        # all-gather and de-quantizes back to the parameter dtype
        def fsdp_pre_all_gather(
            self,
            mesh: DeviceMesh,
            outer_size: torch.Size,
            outer_stride: Tuple[int, ...],
            module: nn.Module,
            mp_policy: MixedPrecisionPolicy,
        ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
            # `tls` is only read here (an attribute is set on it), so no
            # `nonlocal` declaration is needed.
            tls.ran_pre_all_gather = True
            return (self.to(torch.bfloat16),), None

        @torch.no_grad()
        def fsdp_post_all_gather(
            self,
            all_gather_outputs: Tuple[torch.Tensor, ...],
            metadata: Any,
            param_dtype: torch.dtype,
            *,
            out: Optional[torch.Tensor] = None,
        ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
            (tensor,) = all_gather_outputs
            assert metadata is None, f"{metadata}"
            assert tensor.dtype == torch.bfloat16, f"{tensor.dtype}"
            if out is not None:
                with _unsafe_preserve_version_counter(out):
                    out.copy_(tensor)
                return
            upcast_tensor = tensor.to(param_dtype)
            return upcast_tensor, (tensor, upcast_tensor)

        with torch.device("meta"):
            model = self._init_two_tensor_mlp()
        for mlp in model:
            fully_shard(mlp)
        fully_shard(model)
        model.to_empty(device=self.device)
        for param in model.parameters():
            nn.init.trunc_normal_(param)
        # Monkey patch the pre/post-all-gather functions *after* `to_empty()`
        # since the local tensor objects change from materialization
        self.assertGreater(sum("weight" in n for n, _ in model.named_parameters()), 0)
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                # Need to use `_local_tensor` to patch the tensor object
                local_param = param._local_tensor
                # Monkey patch on the `torch.Tensor` as instance methods to
                # show that the extension can work even without a subclass
                local_param.fsdp_pre_all_gather = fsdp_pre_all_gather.__get__(
                    local_param
                )
                local_param.fsdp_post_all_gather = fsdp_post_all_gather.__get__(
                    local_param
                )
        optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)

        # Run a few iterations to check for errors
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 8), device="cuda")
        for _ in range(3):
            model(inp).sum().backward()
            optim.step()
            optim.zero_grad()
        # Use the test-framework assertion so the check is not stripped by -O.
        self.assertTrue(tls.ran_pre_all_gather)

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_all_gather_extension_outer_size_stride(self):
        """
        NOTE: We cannot easily test the incorrect case where the user-defined
        ``fsdp_pre_all_gather`` does not correctly pad the local tensor because
        only some ranks may require padding, in which case only those ranks
        will error out and the all-gather will timeout.
        """
        assert (
            self.world_size >= 2
        ), f"Assumes world size of at least 2 but got {self.world_size=}"
        # dim=3 gives weight dims that are not divisible by the world size,
        # so some ranks need padding in `fsdp_pre_all_gather`.
        model = MLP(dim=3, dim_multiplier=3)
        for module in model.modules():
            for param_name, param in module.named_parameters(recurse=False):
                if "weight" in param_name:
                    param = nn.Parameter(BFloat16AllGatherTensor(param))
                    setattr(module, param_name, param)
        fully_shard(model)
        optim = torch.optim.AdamW(model.parameters(), lr=1e-2, fused=True)
        torch.manual_seed(42 + self.rank + 1)
        inp = torch.randn((2, 3), device="cuda")
        loss = model(inp).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()

    @unittest.skipIf(not TEST_CUDA, "no cuda")
    def test_all_gather_extension_hsdp_mesh(self):
        # Thread-local slot for the mesh passed to the patched pre-hook;
        # initialized so a hook that never runs fails the assertion cleanly.
        tls = threading.local()
        tls.mesh = None
        replicate_size = 2
        shard_size = self.world_size // replicate_size
        mesh = init_device_mesh(
            "cuda",
            (replicate_size, shard_size),
            mesh_dim_names=("dp_replicate", "dp_shard"),
        )

        def fsdp_pre_all_gather(
            self,
            mesh: DeviceMesh,
            outer_size: torch.Size,
            outer_stride: Tuple[int, ...],
            module: nn.Module,
            mp_policy: MixedPrecisionPolicy,
        ) -> Tuple[Tuple[torch.Tensor, ...], Any]:
            tls.mesh = mesh
            return (self,), None

        @torch.no_grad()
        def fsdp_post_all_gather(
            self,
            all_gather_outputs: Tuple[torch.Tensor, ...],
            metadata: Any,
            param_dtype: torch.dtype,
            *,
            out: Optional[torch.Tensor] = None,
        ) -> Union[Tuple[torch.Tensor, Tuple[torch.Tensor, ...]], None]:
            (tensor,) = all_gather_outputs
            if out is not None:
                return
            return tensor, (tensor,)

        model = self._init_two_tensor_mlp()
        for mlp in model:
            fully_shard(mlp, mesh=mesh)
        fully_shard(model, mesh=mesh)
        self.assertGreater(sum("weight" in n for n, _ in model.named_parameters()), 0)
        for param_name, param in model.named_parameters():
            if "weight" in param_name:
                # Need to use `_local_tensor` to patch the tensor object
                local_param = param._local_tensor
                # Monkey patch on the `torch.Tensor` as instance methods to
                # show that the extension can work even without a subclass
                local_param.fsdp_pre_all_gather = fsdp_pre_all_gather.__get__(
                    local_param
                )
                local_param.fsdp_post_all_gather = fsdp_post_all_gather.__get__(
                    local_param
                )
        inp = torch.randn((2, 8), device="cuda")
        model(inp)
        # Check that FSDP passes only the shard mesh to the pre-all-gather
        self.assertIsNotNone(tls.mesh)
        self.assertEqual(tls.mesh.ndim, 1)
        self.assertEqual(tls.mesh.size(), shard_size)
if __name__ == "__main__":
    # Script entry point: hand off to `run_tests` from
    # torch.testing._internal.common_utils.
    run_tests()
|