File: test_fsdp_optim_state.py

# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class FsdpOptimStateCheckpoint(DTensorTestBase):
    def _create_model(self):
        # make each weight tensor's dim 0 a multiple of the world size for the scaling test
        layer1_weight_dim = self.world_size
        layer2_weight_dim = self.world_size * 2
        layer3_weight_dim = self.world_size * 3

        class TestDummyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.net1 = nn.Sequential(nn.Linear(8, layer1_weight_dim), nn.ReLU())
                self.net2 = nn.Sequential(
                    nn.Linear(layer1_weight_dim, layer2_weight_dim), nn.ReLU()
                )
                self.net3 = nn.Sequential(
                    nn.Linear(layer2_weight_dim, layer3_weight_dim), nn.ReLU()
                )

            def forward(self, x):
                return self.net3(self.net2(self.net1(x)))

            def get_input(self):
                return torch.rand(8, 8, device="cuda")

        model = TestDummyModel().cuda()
        return model

    @property
    def backend(self):
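        # map CPU tensors to the gloo backend and CUDA tensors to NCCL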
        return "cpu:gloo,cuda:nccl"

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    @parametrize("pass_planner", [True, False])
    def test_load_sharded_optimizer_state_dict(self, pass_planner) -> None:
        CHECKPOINT_DIR = self.temp_dir
        planner = dcp.DefaultLoadPlanner() if pass_planner else None

        model = self._create_model()
        model = FSDP(model)
        optim = torch.optim.Adam(model.parameters(), lr=0.1)

        # step ahead to initialize the optimizer
        model(model.get_input()).sum().backward()
        optim.step()

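        # SHARDED_STATE_DICT makes state_dict() (and the optimizer state dict
        # below) return per-rank ShardedTensor shards instead of full tensors.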
        FSDP.set_state_dict_type(
            model,
            StateDictType.SHARDED_STATE_DICT,
        )
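        # re-key the optimizer state by parameter FQN, sharded per the state
        # dict type configured above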
        optim_osd = FSDP.optim_state_dict(model, optim)

        state_dict = {
            "model": model.state_dict(),
            "optim": optim_osd,
        }
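        # save writes each rank's shards plus shared metadata to CHECKPOINT_DIR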
        dcp.save(
            state_dict=state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
        )

        # now load the model and ensure the values are the same
        model_2 = self._create_model()
        model_2 = FSDP(model_2)
        optim_2 = torch.optim.Adam(model_2.parameters(), lr=0.1)

        FSDP.set_state_dict_type(
            model_2,
            StateDictType.SHARDED_STATE_DICT,
        )
        # Adam lazily creates its state
        self.assertEqual(0, len(optim_2.state))

        state_dict = {
            "model": model_2.state_dict(),
            # the optimizer state cannot be loaded together with the model here;
            # it is loaded separately below via load_sharded_optimizer_state_dict
        }
        dcp.load(
            state_dict=state_dict,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
        )
        model_2.load_state_dict(state_dict["model"])

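        # read the optimizer state from the checkpoint, resharded to match the
        # loaded model state dict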
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=state_dict["model"],
            optimizer_key="optim",
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
            planner=planner,
        )
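        # convert the FQN-keyed sharded state into the flattened form the
        # wrapped optimizer expects before load_state_dict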
        flattened_osd = FSDP.optim_state_dict_to_load(
            model_2, optim_2, optim_state["optim"]
        )
        optim_2.load_state_dict(flattened_osd)
        osd_after_load = FSDP.optim_state_dict(model_2, optim_2)

        # Compare optim_state_dict prior to save and after load
        before_optim_state = optim_osd["state"]
        after_optim_state = osd_after_load["state"]
        self.assertEqual(len(before_optim_state), len(after_optim_state))
        for fqn, states in before_optim_state.items():
            for state_name, state in states.items():
                state2 = after_optim_state.get(fqn).get(state_name)
                if isinstance(state, ShardedTensor):
                    self.assertIsInstance(state2, ShardedTensor)
                    self.assertTrue(torch.allclose(state, state2))
                else:
                    self.assertEqual(state, state2)


instantiate_parametrized_tests(FsdpOptimStateCheckpoint)
if __name__ == "__main__":
    run_tests()