# Owner(s): ["oncall: distributed"]
from unittest.mock import patch
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
TensorProperties,
TensorStorageMetadata,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class TestDCPCompatbility(TestCase):
def test_metadata(self) -> None:
# Ensure that all the new fields of all the metadata have the default
# values so that we can always deserialize from a legacy metadata.
try:
tensor = torch.zeros(4, 4)
chunk_meta = ChunkStorageMetadata(
torch.Size((1, 1)),
torch.Size((1, 1)),
)
tensor_meta = TensorStorageMetadata(
properties=TensorProperties.create_from_tensor(tensor),
size=tensor.size(),
chunks=[chunk_meta],
)
b_meta = BytesStorageMetadata()
_ = Metadata(state_dict_metadata={"a": tensor_meta, "b": b_meta})
_ = MetadataIndex(fqn="a.b.c")
except Exception as e:
raise RuntimeError(
"The change may break the BC of distributed checkpoint."
) from e
def test_sharded_tensor_dependency(self) -> None:
# Ensure that we can load the existing DCP checkpoints back even if the
# metadata contain # _shard.sharded_tensor.metadata.
from torch.distributed._shard.sharded_tensor.metadata import (
TensorProperties as stp,
)
with patch("torch.distributed.checkpoint.metadata.TensorProperties", stp):
dcp.save(
{"a": torch.zeros(4, 4)},
dcp.FileSystemWriter("/tmp/dcp_testing"),
)
dcp.load(
{"a": torch.zeros(4, 4)},
dcp.FileSystemReader("/tmp/dcp_testing"),
)
@with_temp_dir
def test_storage_meta(self) -> None:
writer = dcp.FileSystemWriter(self.temp_dir)
dcp.save({"a": torch.zeros(4, 4)}, storage_writer=writer)
reader = dcp.FileSystemReader(self.temp_dir)
storage_meta = reader.read_metadata().storage_meta
self.assertNotEqual(storage_meta, None)
self.assertEqual(str(storage_meta.checkpoint_id), self.temp_dir)
self.assertEqual(storage_meta.save_id, writer.save_id)
self.assertEqual(storage_meta.load_id, reader.load_id)
@with_temp_dir
def test_with_v_2_3(self) -> None:
sd = {
"a": torch.zeros(4, 4),
"dict": {
"dict_a": {"dict_a_1": 1, "dict_a_2": 2},
"dict_b": {"dict_b_1": 1, "dict_b_2": 2},
},
"list": [0, 1, 2, 3, 4, 5],
}
load_sd = {
"a": torch.ones(4, 4),
"dict": {
"dict_a": {"dict_a_1": 2, "dict_a_2": 4},
"dict_b": {"dict_b_1": 2, "dict_b_2": 4},
},
"list": [10, 11, 12, 13, 14, 15],
}
dcp._version._act_like_version = "2_3"
dcp.save(sd, checkpoint_id=self.temp_dir)
dcp._version._act_like_version = None
dcp.load(load_sd, checkpoint_id=self.temp_dir)
self.assertEqual(sd, load_sd)
# Entry point: delegate to torch's common test runner, which handles
# argument parsing, retries, and CI result reporting.
if __name__ == "__main__":
    run_tests()