# Owner(s): ["oncall: distributed"]
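
# This end-to-end test combines FSDP with expert parallelism (EP): each rank owns
# one expert layer, and get_state_dict() is expected to return DTensors placed on
# the appropriate device meshes for both expert and non-expert parameters.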
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor
from torch.distributed.checkpoint.state_dict import get_state_dict
from torch.distributed.device_mesh import _mesh_resources, init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
from torch.testing._internal.distributed.common_state_dict import VerifyStateDictMixin


class Dummymodel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        raise NotImplementedError


class EPModel(nn.Module):
    def __init__(self, rank):
        super().__init__()
        self.net1 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


class SecondTier(nn.Module):
    def __init__(self, rank):
        super().__init__()
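        # Each rank hosts exactly one real expert layer (slot rank % 4); the other
        # slots hold parameter-free Dummymodel placeholders.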
        self.ep_layers = nn.ModuleList(
            [EPModel(rank) if rank % 4 == i else Dummymodel() for i in range(4)]
        )
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


class TopModel(nn.Module):
    def __init__(self, rank):
        super().__init__()
        torch.manual_seed(0)
        self.second = SecondTier(rank)
        self.net = nn.Sequential(nn.Linear(16, 16), nn.ReLU())

    def forward(self, x):
        raise NotImplementedError


class TestFSDPWithEP(DTensorTestBase, VerifyStateDictMixin):
    @property
    def world_size(self) -> int:
        return min(8, torch.cuda.device_count())

    @with_comms
    @skip_if_lt_x_gpu(8)
    @with_temp_dir
    def test_e2e(self):
        model = TopModel(self.rank).cuda()

        mesh_fsdp_tp = init_device_mesh(
            self.device_type, (2, 4), mesh_dim_names=("dp", "tp")
        )
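
        # Carve out a 1-D "dp" submesh (two ranks per expert group) from the 2x4
        # mesh; the expert layers below are sharded only across this submesh.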
        # TODO: we are using an internal API for now. Change to a public API once it is ready.
        mesh_fsdp_ep = _mesh_resources.create_child_mesh(mesh_fsdp_tp, ("dp",))
        # Drop the child->parent bookkeeping entry so that FSDP treats the "dp"
        # submesh as an independent mesh rather than a slice of the 2D mesh.
        del _mesh_resources.child_to_parent_mapping[mesh_fsdp_ep]

        mesh_fsdp = init_device_mesh(self.device_type, (8,))
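        # Wrap each expert layer with FSDP over the 2-rank EP submesh, and wrap
        # the shared (non-expert) modules over the full 8-rank mesh.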
        for i, l in enumerate(model.second.ep_layers):
            model.second.ep_layers[i] = FSDP(
                l, use_orig_params=True, device_mesh=mesh_fsdp_ep
            )
        model.second = FSDP(model.second, use_orig_params=True, device_mesh=mesh_fsdp)
        model = FSDP(model, use_orig_params=True, device_mesh=mesh_fsdp)

        optim = torch.optim.Adam(model.parameters(), lr=0.1)
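        # get_state_dict() should return DTensor-backed model and optimizer state
        # dicts whose device meshes reflect each parameter's FSDP placement.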
        msd, osd = get_state_dict(model, optim)

        # FSDP-only params: these should be sharded over the full 8-rank mesh.
        for key in (
            "net.0.weight",
            "net.0.bias",
            "second.net.0.weight",
            "second.net.0.bias",
        ):
            msd_v = msd[key]
            osd_v = osd["state"][key]["exp_avg"]
            for v in (msd_v, osd_v):
                self.assertTrue(isinstance(v, DTensor))
                self.assertEqual(tuple(v.device_mesh.mesh), tuple(range(8)))

        # FSDP/EP params
        layer = self.rank % 4
        ranks = (layer, layer + 4)
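        # Only the expert owned by this rank should appear in the state dicts, as
        # DTensors on the two-rank submesh; the other experts' keys must be absent.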
        for i in range(4):
            for key in (
                f"second.ep_layers.{i}.net1.0.weight",
                f"second.ep_layers.{i}.net1.0.bias",
                f"second.ep_layers.{i}.net2.0.weight",
                f"second.ep_layers.{i}.net2.0.bias",
            ):
                if layer != i:
                    self.assertTrue(key not in msd)
                else:
                    msd_v = msd[key]
                    osd_v = osd["state"][key]["exp_avg"]
                    for v in (msd_v, osd_v):
                        self.assertTrue(isinstance(v, DTensor))
                        self.assertEqual(tuple(v.device_mesh.mesh), ranks)

        self.assertEqual(set(osd["state"].keys()), set(msd.keys()))


if __name__ == "__main__":
    run_tests()