# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._state_dict_utils import _all_gather_sharded_tensor
from torch.distributed._tensor import DTensor, init_device_mesh, Replicate
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


# TODO: modularize this test and add tests for checkpoint conversion in both
# directions.
class TestFsdpTpCheckpointConversion(DTensorTestBase):
    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_fsdp_to_tp(self):
        CHECKPOINT_DIR = self.temp_dir

        model = MLPModule(self.device_type).cuda(self.rank)
        # Create an FSDP-wrapped model.
        fsdp_model = FSDP(model, use_orig_params=True)
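        # SHARDED_STATE_DICT makes state_dict() return each parameter as a
        # ShardedTensor holding only this rank's shard, which is what allows
        # the checkpoint to be resharded when loaded into a model with a
        # different layout.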
        FSDP.set_state_dict_type(
            fsdp_model,
            StateDictType.SHARDED_STATE_DICT,
        )
        fsdp_state_dict = fsdp_model.state_dict()

        # save fsdp_state_dict to storage
        dist_cp.save(
            state_dict=fsdp_state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
        )
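
        # The checkpoint on disk stores shard data together with global-shape
        # metadata, which is what lets DCP reshard it on load.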
        # create a TP wrapped model
        mesh_shape = (self.world_size,)
        device_mesh = init_device_mesh(self.device_type, mesh_shape)
        model = MLPModule(self.device_type).cuda(self.rank)
        # Parallelize the module based on the given parallel styles.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        tp_model = parallelize_module(model, device_mesh, parallelize_plan)
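        # After parallelize_module, net1/net2 parameters are DTensors sharded
        # over the 1-D device mesh: ColwiseParallel shards net1's weight along
        # the output dimension and RowwiseParallel shards net2's weight along
        # the input dimension.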
        optimizer = torch.optim.SGD(tp_model.parameters(), lr=0.25)

        # Update the parameters so tp_model.state_dict() will be different
        # from fsdp_model.state_dict().
        torch.manual_seed(0)
        inp = torch.rand(20, 10).cuda(self.rank)
        output = tp_model(inp)
        output.sum().backward()
        optimizer.step()

        tp_state_dict = tp_model.state_dict()

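        # To compare a ShardedTensor against a DTensor below, both sides are
        # materialized as full local tensors: the FSDP shards are all-gathered,
        # and the DTensor is redistributed to Replicate and unwrapped with
        # to_local().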
        # Check parameters are indeed different prior to loading.
        for fsdp_item, tp_item in zip(fsdp_state_dict.items(), tp_state_dict.items()):
            fsdp_k, fsdp_v = fsdp_item
            tp_k, tp_v = tp_item
            self.assertEqual(fsdp_k, tp_k)
            if isinstance(fsdp_v, ShardedTensor) and isinstance(tp_v, DTensor):
                fsdp_redistributed = _all_gather_sharded_tensor(fsdp_v)
                tp_redistributed = tp_v.redistribute(
                    device_mesh, placements=[Replicate()]
                ).to_local()
                self.assertNotEqual(fsdp_redistributed, tp_redistributed)
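
        # Load the FSDP checkpoint into the TP state dict in place: DCP
        # reshards each saved tensor so every DTensor entry receives the
        # slice matching its own placement.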
        dist_cp.load(
            state_dict=tp_state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )
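        # dist_cp.load only fills in the state dict; the values still have to
        # be copied into the module's parameters via load_state_dict.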
        tp_model.load_state_dict(tp_state_dict)

        # Check parameters are equal after loading.
        tp_state_dict_after_load = tp_model.state_dict()
        for fsdp_item, tp_item in zip(
            fsdp_state_dict.items(), tp_state_dict_after_load.items()
        ):
            fsdp_k, fsdp_v = fsdp_item
            tp_k, tp_v = tp_item
            self.assertEqual(fsdp_k, tp_k)
            if isinstance(fsdp_v, ShardedTensor) and isinstance(tp_v, DTensor):
                fsdp_redistributed = _all_gather_sharded_tensor(fsdp_v)
                tp_redistributed = tp_v.redistribute(
                    device_mesh, placements=[Replicate()]
                ).to_local()
                self.assertEqual(fsdp_redistributed, tp_redistributed)


if __name__ == "__main__":
    run_tests()