File: test_utils.py

Package: pytorch-cuda 2.6.0+dfsg-7
# Owner(s): ["oncall: distributed"]
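"""Tests for ``MetadataIndex`` and ``find_state_dict_object`` from
torch.distributed.checkpoint: offset/index-hint semantics for plain
tensors, non-tensor values, and ShardedTensors, plus the DCP logger.
"""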

import sys

import torch
from torch.distributed._shard.sharded_tensor import (
    Shard,
    ShardedTensor,
    ShardedTensorMetadata,
    ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed.c10d_logger import _c10d_logger
from torch.distributed.checkpoint.logger import _dcp_logger
from torch.distributed.checkpoint.metadata import MetadataIndex
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.testing._internal.common_utils import (
    run_tests,
    TEST_WITH_DEV_DBG_ASAN,
    TestCase,
)
from torch.testing._internal.distributed.distributed_utils import with_fake_comms


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


def create_sharded_tensor(rank, world_size, shards_per_rank):
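    """Build a 1-D ShardedTensor split into ``world_size * shards_per_rank``
    shards of length 8 each; only this rank's shards are materialized.

    Shard ``idx`` covers offsets ``[idx * 8, (idx + 1) * 8)`` and is placed
    on rank ``idx // shards_per_rank``. With world_size=2 and
    shards_per_rank=3, rank 0 holds the shards at offsets 0, 8, and 16.
    """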
    shards_metadata = []
    local_shards = []
    for idx in range(world_size * shards_per_rank):
        shard_rank = idx // shards_per_rank
        shard_md = ShardMetadata(
            shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu"
        )
        shards_metadata.append(shard_md)
        if shard_rank == rank:
            shard = Shard.from_tensor_and_offsets(
                torch.rand(*shard_md.shard_sizes),
                shard_offsets=shard_md.shard_offsets,
                rank=rank,
            )
            local_shards.append(shard)

    sharded_tensor_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([8 * len(shards_metadata)]),
        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
    )

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
    )


class TestMetadataIndex(TestCase):
    def test_init_convert_offset(self):
        a = MetadataIndex("foo", [1, 2])
        b = MetadataIndex("foo", torch.Size([1, 2]))
        self.assertEqual(a, b)

    def test_index_hint_ignored_on_equals(self):
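        # ``index`` is only a lookup hint and is excluded from equality
        # (and, in the next test, from hashing).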
        a = MetadataIndex("foo")
        b = MetadataIndex("foo", index=99)
        self.assertEqual(a, b)

    def test_index_hint_ignored_on_hash(self):
        a = MetadataIndex("foo")
        b = MetadataIndex("foo", index=99)
        self.assertEqual(hash(a), hash(b))

    def test_flat_data(self):
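        # Plain tensors tolerate an origin offset and an arbitrary index
        # hint; non-tensor values accept index hints but reject offsets,
        # as the asserts at the end of this test show.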
        state_dict = {
            "a": torch.rand(10),
            "b": [1, 2, 3],
        }

        a = find_state_dict_object(state_dict, MetadataIndex("a"))
        self.assertEqual(a, state_dict["a"])
        a = find_state_dict_object(state_dict, MetadataIndex("a", [0]))
        self.assertEqual(a, state_dict["a"])
        a = find_state_dict_object(state_dict, MetadataIndex("a", index=99))
        self.assertEqual(a, state_dict["a"])

        b = find_state_dict_object(state_dict, MetadataIndex("b"))
        self.assertEqual(b, state_dict["b"])
        b = find_state_dict_object(state_dict, MetadataIndex("b", index=1))
        self.assertEqual(b, state_dict["b"])

        with self.assertRaisesRegex(ValueError, "FQN"):
            find_state_dict_object(state_dict, MetadataIndex("c"))
        with self.assertRaisesRegex(ValueError, "ShardedTensor"):
            find_state_dict_object(state_dict, MetadataIndex("b", [1]))

    @with_fake_comms(rank=0, world_size=2)
    def test_sharded_tensor_lookup(self):
        st = create_sharded_tensor(rank=0, world_size=2, shards_per_rank=3)
        state_dict = {"st": st}

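        # Rank 0 owns the shards at offsets 0, 8, and 16, so offset [8]
        # resolves to its second local shard even without an index hint.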
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8]))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # good hint: index 1 is exactly where the offset-[8] shard sits
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=1))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # bad hint: the wrong index is ignored and the shard is still found by its offset
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=2))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        # broken hint: an out-of-range index likewise falls back to the offset search
        obj = find_state_dict_object(state_dict, MetadataIndex("st", [8], index=99))
        self.assertEqual(obj, st.local_shards()[1].tensor)

        with self.assertRaisesRegex(ValueError, "no offset was provided"):
            find_state_dict_object(state_dict, MetadataIndex("st"))

        with self.assertRaisesRegex(ValueError, "Could not find shard"):
            find_state_dict_object(state_dict, MetadataIndex("st", [1]))

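    # DCP keeps its own logger, separate from the c10d logger; the c10d
    # logger should carry exactly one handler.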
    def test_dcp_logger(self):
        self.assertIsNot(_c10d_logger, _dcp_logger)
        self.assertEqual(1, len(_c10d_logger.handlers))


if __name__ == "__main__":
    run_tests()