# Owner(s): ["oncall: distributed"]
import os
from unittest.mock import patch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class MyTestModule(nn.Module):
    """Small feed-forward network used as the model under test.

    Four stages are applied in sequence; feature sizes go
    8 -> 16 -> 32 -> 64 -> 8.
    """

    def __init__(self) -> None:
        super().__init__()
        # Stages are created and registered in net1..net4 order, which fixes
        # both the parameter-initialization order and the state_dict key order.
        self.net1 = nn.Sequential(
            nn.Linear(in_features=8, out_features=16),
            nn.ReLU(),
        )
        self.net2 = nn.Sequential(
            nn.Linear(in_features=16, out_features=32),
            nn.ReLU(),
        )
        self.net3 = nn.Linear(in_features=32, out_features=64)
        self.net4 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=8),
        )

    def forward(self, x):
        """Return ``x`` passed through net1, net2, net3 and net4, in that order."""
        activations = x
        for stage in (self.net1, self.net2, self.net3, self.net4):
            activations = stage(activations)
        return activations
class TestSaveAndLoadAPI(DTensorTestBase):
    """Exercises ``dcp.save`` / ``dcp.load`` when only a ``checkpoint_id`` is given."""

    @property
    def world_size(self) -> int:
        """World size used by this test case; also the 1-D device mesh size below."""
        return 2

    # NOTE(review): the skip threshold (4 GPUs) is higher than world_size (2);
    # confirm whether that is intentional before changing it.
    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_auto_detect(self):
        """Save and load an FSDP model's state dict by ``checkpoint_id`` alone.

        Three scenarios are covered:

        1. A plain directory path under ``self.temp_dir`` round-trips.
        2. The same round-trip still works while
           ``FileSystemReader`` / ``FileSystemWriter.validate_checkpoint_id``
           are patched to return ``False`` (presumably so that another storage
           backend has to be picked -- TODO confirm which one).
        3. An id of the form ``abc://abc.abc`` makes both ``dcp.save`` and
           ``dcp.load`` raise ``RuntimeError`` matching ``"Cannot detect"``.
        """
        model = FSDP(MyTestModule().cuda())
        device_mesh = init_device_mesh(self.device_type, (self.world_size,))
        model = FSDP(model, device_mesh=device_mesh)

        # self.temp_dir is presumably provided by @with_temp_dir -- TODO confirm.
        first_ckpt = os.path.join(self.temp_dir, "first")
        second_ckpt = os.path.join(self.temp_dir, "second")

        # Scenario 1: unpatched save/load against a local directory.
        dcp.save(model.state_dict(), checkpoint_id=first_ckpt)
        dcp.load(model.state_dict(), checkpoint_id=first_ckpt)

        # Scenario 2: both filesystem classes report the id as not valid for
        # them; save/load are still expected to succeed.
        with patch.object(
            dcp.FileSystemReader, "validate_checkpoint_id", return_value=False
        ), patch.object(
            dcp.FileSystemWriter, "validate_checkpoint_id", return_value=False
        ):
            dcp.save(model.state_dict(), checkpoint_id=second_ckpt)
            dcp.load(model.state_dict(), checkpoint_id=second_ckpt)

        # Scenario 3: an id that cannot be resolved must fail loudly on both paths.
        with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
            dcp.save(model.state_dict(), checkpoint_id="abc://abc.abc")
        with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
            dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")
# When this file is executed as a script, hand control to run_tests().
if __name__ == "__main__":
    run_tests()
|