# Owner(s): ["oncall: distributed"]
import os
from unittest.mock import patch
import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class MyTestModule(nn.Module):
    """Small feed-forward fixture model (8 -> 16 -> 32 -> 64 -> 8).

    Used as a throwaway module for the checkpoint save/load tests below.
    """

    def __init__(self) -> None:
        super().__init__()
        # NOTE: submodule construction order fixes the parameter-init RNG
        # sequence; keep it stable so seeded runs stay reproducible.
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        """Run the input through the four stages in order."""
        out = x
        for stage in (self.net1, self.net2, self.net3, self.net4):
            out = stage(out)
        return out
class TestSaveAndLoadAPI(DTensorTestBase):
    """Tests for dcp.save/dcp.load storage auto-detection and key validation."""

    @property
    def world_size(self) -> int:
        # Two ranks suffice for the collectives exercised in these tests.
        return 2

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_auto_detect(self):
        """save/load should infer the storage reader/writer from checkpoint_id."""
        model = FSDP(MyTestModule().cuda())
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        model = FSDP(model, device_mesh=mesh)

        # Plain filesystem path: the default reader/writer should claim it.
        first = os.path.join(self.temp_dir, "first")
        dcp.save(model.state_dict(), checkpoint_id=first)
        dcp.load(model.state_dict(), checkpoint_id=first)

        # Even with the filesystem components declining to validate the id,
        # save/load is still expected to succeed for this path.
        second = os.path.join(self.temp_dir, "second")
        with patch.object(
            dcp.FileSystemReader, "validate_checkpoint_id", return_value=False
        ), patch.object(
            dcp.FileSystemWriter, "validate_checkpoint_id", return_value=False
        ):
            dcp.save(model.state_dict(), checkpoint_id=second)
            dcp.load(model.state_dict(), checkpoint_id=second)

        # An unknown URL scheme matches no storage backend and must raise.
        with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
            dcp.save(model.state_dict(), checkpoint_id="abc://abc.abc")
        with self.assertRaisesRegex(RuntimeError, "Cannot detect"):
            dcp.load(model.state_dict(), checkpoint_id="abc://abc.abc")

    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_assert_same_keys(self):
        """Test the `_assert_same_keys` function."""
        sd = MyTestModule().state_dict()
        # Identical keys on every rank: the check passes silently.
        dcp.utils._assert_same_keys(sd)
        # Insert a rank-dependent key so the key sets diverge across ranks.
        sd["abc" if self.rank == 0 else "def"] = torch.rand(1)
        with self.assertRaises(AssertionError):
            dcp.utils._assert_same_keys(sd)
# Standard PyTorch test-suite entry point when this file is run directly.
if __name__ == "__main__":
    run_tests()