# Owner(s): ["oncall: distributed"]
import sys
import tempfile

import torch
from torch import distributed as dist
from torch.distributed._shard.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    save_state_dict,
    load_state_dict,
)
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from torch.distributed.fsdp.wrap import enable_wrap, wrap
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import (
    FSDPTest,
    SkipModel,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
)

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
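
# State dict flavors exercised by the save/load round trip below; the
# parametrized test runs once per entry.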
_DISTRIBUTED_STATE_DICT_IMPLS = {
    StateDictType.LOCAL_STATE_DICT,
    StateDictType.SHARDED_STATE_DICT,
}


class TestDistributedCheckpoint(FSDPTest):
    @property
    def world_size(self):
        return 2

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _DISTRIBUTED_STATE_DICT_IMPLS)
    def test_distributed_checkpoint(self, state_dict_type) -> None:
        with enable_wrap(wrapper_cls=FSDP):
            torch.manual_seed(100)
            model = wrap(SkipModel(double_nest=True))
            torch.manual_seed(200)
            new_model = wrap(SkipModel(double_nest=True))

        # The two models are seeded differently, so their parameters should
        # not match before the checkpoint is loaded.
        with FullyShardedDataParallel.summon_full_params(
            model
        ), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertNotEqual(params, new_params)

        with tempfile.TemporaryDirectory() as path:
            # Every rank creates a temporary directory, but only rank 0's path
            # (broadcast to all ranks) is used, so all ranks checkpoint to the
            # same location.
            paths = [path]
            dist.broadcast_object_list(paths)
            path = paths[0]
            writer = FileSystemWriter(path)
            reader = FileSystemReader(path)

            # Save `model`'s sharded state dict to the checkpoint directory.
            with FSDP.state_dict_type(
                model, state_dict_type
            ), FSDP.state_dict_type(new_model, state_dict_type):
                state_dict = model.state_dict()
                save_state_dict(state_dict, writer)

            # Load the checkpoint into `new_model`'s state dict in place, then
            # apply it to `new_model`.
            with FSDP.state_dict_type(
                model, state_dict_type
            ), FSDP.state_dict_type(new_model, state_dict_type):
                state_dict = new_model.state_dict()
                load_state_dict(state_dict, reader)
                new_model.load_state_dict(state_dict)

        # After loading the checkpoint, the parameters of the two models
        # should be identical.
        with FullyShardedDataParallel.summon_full_params(
            model
        ), FullyShardedDataParallel.summon_full_params(new_model):
            params = list(model.parameters())
            new_params = list(new_model.parameters())
            self.assertEqual(params, new_params)

        # TODO: add resharding test case.


instantiate_parametrized_tests(TestDistributedCheckpoint)

if __name__ == "__main__":
    run_tests()