1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486
|
# Owner(s): ["oncall: distributed"]
import copy
import sys
from collections import OrderedDict
from typing import Any, Dict, List, NamedTuple, Optional, Tuple
import torch
import torch.distributed._shard.sharding_spec as shard_spec
from torch import distributed as dist
from torch.distributed._shard import shard_module
from torch.distributed._shard.sharded_tensor.api import Shard, ShardedTensor
from torch.distributed._shard.sharding_plan import ShardingPlan
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed.fsdp._fsdp_extensions import _set_fsdp_extensions, FSDPExtensions
from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
from torch.distributed.fsdp._utils import _set_fsdp_flattened
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
FullyShardedDataParallel as FSDP,
StateDictType,
)
from torch.distributed.remote_device import _remote_device
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
sys.exit(0)
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
sys.exit(0)
class STShardingInfo(NamedTuple):
""":class:`ShardedTensor` sharding information."""
sharding_spec: shard_spec.ShardingSpec
global_size: torch.Size
process_group: dist.ProcessGroup
class ShardedTensorExtensions(FSDPExtensions):
def pre_flatten_transform(
self,
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[Any]]:
if type(tensor) is ShardedTensor:
param_name_to_sharding_info = STShardingInfo(
tensor.sharding_spec(), tensor.size(), tensor._process_group
)
local_tensor = tensor.local_tensor()
return local_tensor, param_name_to_sharding_info
return tensor, None
def post_unflatten_transform(
self, tensor: torch.Tensor, param_name_to_sharding_info: STShardingInfo
) -> torch.Tensor:
sharded_tensor = ShardedTensor._init_from_local_tensor(
tensor,
_rewrite_spec_if_needed(
param_name_to_sharding_info.sharding_spec,
tensor,
dist.get_rank(param_name_to_sharding_info.process_group),
),
param_name_to_sharding_info.global_size,
process_group=param_name_to_sharding_info.process_group,
)
_set_fsdp_flattened(sharded_tensor)
return sharded_tensor
def chunk_tensor(
self,
tensor: torch.Tensor,
rank: int,
world_size: int,
num_devices_per_node: int,
pg: dist.ProcessGroup,
) -> torch.Tensor:
if type(tensor) is ShardedTensor:
assert len(tensor.local_shards()) == 1
inner_param = tensor.local_tensor()
inner_st = _create_chunk_sharded_tensor(
inner_param,
rank,
world_size,
num_devices_per_node,
pg,
)
outer_local_shard = tensor.local_shards()[0]
shards: List[Shard] = [
Shard(inner_st, copy.deepcopy(outer_local_shard.metadata))
]
st_meta = copy.deepcopy(tensor.metadata())
st_meta.tensor_properties.requires_grad = False
st_outer = ShardedTensor._init_from_local_shards_and_global_metadata(
shards,
sharded_tensor_metadata=st_meta,
process_group=tensor._process_group,
init_rrefs=False,
)
return st_outer
else:
return _create_chunk_sharded_tensor(
tensor, rank, world_size, num_devices_per_node, pg
)
def pre_load_state_dict_transform(
self,
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
shards = tensor.local_shards()
# default impl removes this line
if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor:
tensor = shards[0].tensor
shards = tensor.local_shards()
return (tensor, [shards[0].tensor] if len(shards) > 0 else [])
_set_fsdp_extensions(ShardedTensorExtensions())
def _rewrite_spec_if_needed(
spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int
):
"""
Rewrites ``spec`` to match the device of ``tensor``.
Curerntly, ``FSDP.sharded_optim_state_dict` moves optimizer state to CPU
(without choice), so if the original ``ShardingSpec`` produces non-CPU
metadta, then the ST construction errors.
"""
if not isinstance(spec, ChunkShardingSpec):
return spec
# Determine if we need to rewrite the spec
rewrite = False
for p in spec.placements:
if p.rank() == rank and p.device() != tensor.device:
rewrite = True
break
if rewrite:
spec = copy.deepcopy(spec)
for i, placement in enumerate(spec.placements):
if placement.rank() == rank and placement.device() != tensor.device:
spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")
return spec
def _is_nested_tensor(val: Any) -> bool:
if type(val) is ShardedTensor:
if len(val.local_shards()) == 0:
return False
if type(val.local_shards()[0].tensor) is ShardedTensor:
return True
return False
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.net1 = torch.nn.Linear(5, 8)
self.relu = torch.nn.ReLU()
self.net2 = torch.nn.Linear(8, 4)
self.net3 = torch.nn.Linear(4, 12)
def forward(self, x):
return self.net3(self.net2(self.relu(self.net1(x))))
@staticmethod
def module_sharding_plan(specs):
colwise_spec, rowwise_spec = specs[0], specs[1]
return ShardingPlan(
plan={
"net1.weight": colwise_spec,
"net2.weight": rowwise_spec,
},
output_plan={
"net2": colwise_spec,
},
return_local_tensor=["net2"],
)
@staticmethod
def get_sharded_param_names() -> List[str]:
return ["net1.weight", "net2.weight"]
@staticmethod
def get_non_sharded_param_names() -> List[str]:
return ["net3.weight", "net3.bias"]
class TestTPFSDPIntegration(FSDPTest):
def _get_params_and_sharding_info(
self,
model: SimpleModel,
sharded_param_names: List[str],
tensor_parallel_size: int,
) -> Tuple[Dict[str, int], Dict[str, Tuple[torch.Size, int]]]:
""" """
assert (
type(model) is SimpleModel
), "Expects a `SimpleModel` since the sharding cases on the model definition"
param_name_to_numel = OrderedDict()
param_name_to_sharding_info = OrderedDict()
for param_name, param in model.named_parameters():
if param_name not in sharded_param_names:
param_name_to_numel[param_name] = param.numel()
else:
param_name_to_numel[param_name] = param.numel() // tensor_parallel_size
param_name_to_sharding_info[param_name] = (
param.size(),
0 if "net1" in param_name else 1,
)
return param_name_to_numel, param_name_to_sharding_info
def _get_sub_pgs(self, tensor_parallel_size: int):
"""
Generates TP and FSDP subprocess groups. ``tensor_parallel_size`` gives
the TP process group size.
For example, if the global world size is 8 and the tensor parallel size
is 2, then this creates:
- 4 TP subprocess groups: [0, 1], [2, 3], [4, 5], [6, 7]
- 2 FSDP subprocess groups: [0, 2, 4, 6], [1, 3, 5, 7]
"""
tp_ranks: List[List[int]] = []
fsdp_ranks: List[List[int]] = []
for rank in range(self.world_size):
tp_idx = rank // tensor_parallel_size
if len(tp_ranks) <= tp_idx:
tp_ranks.append([])
tp_ranks[tp_idx].append(rank)
fsdp_idx = rank % tensor_parallel_size
if len(fsdp_ranks) <= fsdp_idx:
fsdp_ranks.append([])
fsdp_ranks[fsdp_idx].append(rank)
tp_pgs = [dist.new_group(ranks) for ranks in tp_ranks]
fsdp_pgs = [dist.new_group(ranks) for ranks in fsdp_ranks]
tp_pg = tp_pgs[self.rank // tensor_parallel_size]
fsdp_pg = fsdp_pgs[self.rank % tensor_parallel_size]
return tp_pg, fsdp_pg
def _get_chunk_sharding_spec(self, tp_world_size: int, tp_pg: dist.ProcessGroup):
placements = [
f"rank:{idx}/cuda:{dist.distributed_c10d.get_global_rank(tp_pg, idx) % torch.cuda.device_count()}"
for idx in range(tp_world_size)
]
# Rowwise and colwise sharding are specified with respect to the
# transposed linear weight
colwise_spec = ChunkShardingSpec(dim=0, placements=placements)
rowwise_spec = ChunkShardingSpec(dim=1, placements=placements)
return colwise_spec, rowwise_spec
def _sync_tp_grads(
self,
tp_fsdp_model: FSDP,
tp_pg: dist.ProcessGroup,
param_name_to_numel: Dict[str, int],
non_sharded_param_names: List[str],
) -> None:
"""
Syncs the tensor parallel parameters' gradients following the data
parallel paradigm where gradients are averaged over ranks (in this
case, the ones in the tensor parallel process group).
"""
tp_world_size = tp_pg.size()
fsdp_world_size = self.world_size // tp_world_size
assert (
type(tp_fsdp_model) is FSDP and len(list(tp_fsdp_model.parameters())) == 1
), (
"The following logic assumes a single top-level-only FSDP wrapping "
"the model with TP already applied"
)
flat_param = tp_fsdp_model.params[0]
splits = tuple(param_name_to_numel.values())
# Create a mask over the gradient elements to manually reduce
unsharded_size = torch.Size([flat_param.numel() * fsdp_world_size])
unsharded_zeros = torch.zeros(unsharded_size, device=flat_param.device)
per_param_masks = unsharded_zeros.split(splits)
for param_idx, param_name in enumerate(
param_name_to_numel.keys()
): # assumes fixed order
if param_name not in non_sharded_param_names:
per_param_masks[param_idx][:] = 1
unsharded_mask = torch.cat(per_param_masks).contiguous().type(torch.BoolTensor)
sharded_mask = unsharded_mask.chunk(fsdp_world_size)[self.rank // tp_world_size]
grad_device = flat_param.grad.device
grad = flat_param.grad.detach().clone().cuda(self.rank)
dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=tp_pg)
grad = grad.to(grad_device)
flat_param.grad[~sharded_mask] = grad[~sharded_mask]
# Average *all* gradient elements to match the FSDP only semantics
flat_param.grad /= tp_world_size
def _get_grads_as_flattened(
self,
model: FSDP,
uses_tp: bool,
param_name_to_numel: Dict[str, int],
param_name_to_sharding_info: Dict[str, Tuple[torch.Size, int]],
tp_pg: Optional[dist.ProcessGroup],
fsdp_pg: Optional[dist.ProcessGroup],
sharded_param_names: Optional[List[str]],
) -> torch.Tensor:
"""
Returns all unsharded gradients as a single flattened tensor. This
returns the same value on all ranks.
"""
local_grads_as_flattened = (
torch.cat([torch.flatten(param.grad) for param in model.parameters()])
.contiguous()
.cuda(self.rank)
)
all_grads_as_flattened = torch.cat(
[torch.empty_like(local_grads_as_flattened) for _ in range(fsdp_pg.size())]
).contiguous()
dist._all_gather_base(
all_grads_as_flattened, local_grads_as_flattened, group=fsdp_pg
)
if not uses_tp:
return all_grads_as_flattened
splits = tuple(param_name_to_numel.values())
all_grads_per_param = list(all_grads_as_flattened.split(splits))
for param_idx, param_name in enumerate(
param_name_to_numel.keys()
): # assumes fixed order
if param_name in sharded_param_names:
local_tensor_size = list(param_name_to_sharding_info[param_name][0])
sharding_dim = param_name_to_sharding_info[param_name][1]
local_tensor_size[sharding_dim] //= tp_pg.size()
local_tensor = all_grads_per_param[param_idx].view(*local_tensor_size)
local_tensors = [
torch.empty_like(local_tensor) for _ in range(tp_pg.size())
]
dist.all_gather(local_tensors, local_tensor, group=tp_pg)
all_grads_per_param[param_idx] = torch.cat(
local_tensors, dim=sharding_dim
).reshape(-1)
return torch.cat(all_grads_per_param).contiguous()
@skip_if_lt_x_gpu(4)
@parametrize("tensor_parallel_size", [2, 4])
@parametrize(
"cpu_offload",
[CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
)
def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
"""
Tests training for TP + FSDP integration by comparing an FSDP-only
model with a TP + FSDP model.
"""
LR = 3e-5
torch.manual_seed(0)
model = SimpleModel().cuda(self.rank)
tp_fsdp_model = copy.deepcopy(model)
sharded_param_names = SimpleModel.get_sharded_param_names()
non_sharded_param_names = SimpleModel.get_non_sharded_param_names()
(
param_name_to_numel,
param_name_to_sharding_info,
) = self._get_params_and_sharding_info(
model,
sharded_param_names,
tensor_parallel_size,
)
input_seed = self.rank
torch.manual_seed(input_seed + 1)
inp_size = [2, 3, 5]
inp = torch.rand(*inp_size).cuda(self.rank)
self.assertEqual(model(inp), tp_fsdp_model(inp)) # sanity check
tp_pg, fsdp_pg = self._get_sub_pgs(tensor_parallel_size)
fsdp_model = FSDP(
model, process_group=self.process_group, cpu_offload=cpu_offload
)
# Shard with TP and then wrap with FSDP
sharding_specs = self._get_chunk_sharding_spec(tp_pg.size(), tp_pg)
sharding_plan = SimpleModel.module_sharding_plan(sharding_specs)
shard_module(tp_fsdp_model, sharding_plan, process_group=tp_pg)
tp_fsdp_model = FSDP(
tp_fsdp_model, process_group=fsdp_pg, cpu_offload=cpu_offload
)
# Check the forward by checking output equality
fsdp_out = fsdp_model(inp)
tp_fsdp_out = tp_fsdp_model(inp)
self.assertEqual(fsdp_out, tp_fsdp_out)
# Check the backward by checking gradient equality
fsdp_out.sum().backward()
tp_fsdp_out.sum().backward()
self._sync_tp_grads(
tp_fsdp_model,
tp_pg,
param_name_to_numel,
non_sharded_param_names,
)
model_grads = self._get_grads_as_flattened(
fsdp_model,
False,
param_name_to_numel,
param_name_to_sharding_info,
None,
self.process_group,
None,
)
model_tp_grads = self._get_grads_as_flattened(
tp_fsdp_model,
True,
param_name_to_numel,
param_name_to_sharding_info,
tp_pg,
fsdp_pg,
sharded_param_names,
)
self.assertEqual(model_grads, model_tp_grads)
# Check the optimizer step by performing a second forward pass
fsdp_optim = torch.optim.SGD(fsdp_model.parameters(), lr=LR)
tp_fsdp_optim = torch.optim.SGD(tp_fsdp_model.parameters(), lr=LR)
fsdp_optim.step()
tp_fsdp_optim.step()
torch.manual_seed(input_seed + 16)
inp = torch.rand(*inp_size).cuda(self.rank)
fsdp_out = fsdp_model(inp)
tp_fsdp_out = tp_fsdp_model(inp)
self.assertEqual(fsdp_out, tp_fsdp_out)
@skip_if_lt_x_gpu(4)
def test_fsdp_tp_checkpoint_integration(self):
"""Tests checkpointing for TP + FSDP integration."""
tensor_parallel_size = 2
torch.manual_seed(0)
model = SimpleModel().cuda(self.rank)
tp_pg, fsdp_pg = self._get_sub_pgs(tensor_parallel_size)
# Shard with TP and then wrap with FSDP
sharding_specs = self._get_chunk_sharding_spec(tp_pg.size(), tp_pg)
sharding_plan = SimpleModel.module_sharding_plan(sharding_specs)
shard_module(model, sharding_plan, process_group=tp_pg)
tp_fsdp_model = FSDP(model, process_group=fsdp_pg)
# Check that we produce a nested ST from model state dict
with FSDP.state_dict_type(tp_fsdp_model, StateDictType.SHARDED_STATE_DICT):
state_dict = tp_fsdp_model.state_dict()
# TODO once 2D is out, validate the nesting
self.assertTrue(_is_nested_tensor(state_dict["net1.weight"]))
self.assertFalse(_is_nested_tensor(state_dict["net1.bias"]))
tp_fsdp_optim = torch.optim.Adam(tp_fsdp_model.parameters(), lr=0.0001)
input_seed = self.rank
torch.manual_seed(input_seed + 1)
inp_size = [2, 3, 5]
inp = torch.rand(*inp_size).cuda(self.rank)
tp_fsdp_model(inp).sum().backward()
tp_fsdp_optim.step()
# Check that we produce a nested ST from optim state dict
optim_state = FSDP.sharded_optim_state_dict(tp_fsdp_model, tp_fsdp_optim)
# TODO once 2D is out, validate the nesting
self.assertTrue(
_is_nested_tensor(optim_state["state"]["net1.weight"]["exp_avg"])
)
self.assertFalse(
_is_nested_tensor(optim_state["state"]["net1.bias"]["exp_avg"])
)
instantiate_parametrized_tests(TestTPFSDPIntegration)
if __name__ == "__main__":
run_tests()
|