File: test_tp_checkpoint.py

Package: pytorch-cuda 2.6.0+dfsg-7
# Owner(s): ["oncall: distributed"]
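"""Tests for saving and loading tensor-parallel (DTensor) model state dicts
with torch.distributed.checkpoint (DCP), including loading a checkpoint into
an unevenly sharded model initialized on the meta device.
"""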

from copy import deepcopy

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class UnevenShardedModel(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        torch.manual_seed(5)
        self.net1 = torch.nn.Linear(5, 10, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 15, device=device)
        self.net3 = torch.nn.Linear(15, 1, device=device)

    def forward(self, x):
        return self.net3(self.net2(self.relu(self.net1(x))))


class TestTpCheckpoint(DTensorTestBase):
    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint(self):
        CHECKPOINT_DIR = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # Create the model and move it to the GPU matching this rank.
        model = MLPModule(self.device_type).cuda(self.rank)
        # Parallelize the module based on the given parallel styles.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.25)
        original_state_dict = deepcopy(model.state_dict())

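        # Each rank saves only its local shard of the DTensor state dict;
        # FileSystemWriter lays the shards out under CHECKPOINT_DIR.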
        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )

        # Update the parameters so model.state_dict() will be different from original_state_dict.
        torch.manual_seed(0)
        inp = torch.rand(20, 10).cuda(self.rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        state_dict = model.state_dict()

        # Ensure the current model parameters differ from original_state_dict before loading from the checkpoint.
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertNotEqual(param1.to_local(), param2.to_local())

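        # dcp.load reads the checkpoint back into state_dict in place,
        # resharding as needed to match each DTensor's current placements.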
        dcp.load(
            state_dict=state_dict,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )

        # After loading from the checkpoint, the current model parameters should match original_state_dict again.
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertEqual(param1.to_local(), param2.to_local())

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint_load_on_meta_device(self):
        CHECKPOINT_DIR = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # Create the model and move it to the GPU matching this rank.
        model = UnevenShardedModel(self.device_type).cuda(self.rank)
        # Parallelize the module based on the given parallel styles.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan=parallelize_plan)
        original_state_dict = {
            "model": model.state_dict(),
        }

        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
        )

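        # Build a second copy of the model on the meta device so no real
        # parameter storage is allocated until the checkpoint is loaded.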
        model2 = parallelize_module(
            UnevenShardedModel("meta"), tp_mesh, parallelize_plan=parallelize_plan
        )
        model2_sd_before_load = model2.state_dict()
        state_dict_to_load = {"model": model2_sd_before_load}

        dcp.load(
            state_dict=state_dict_to_load,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
        )
        # dcp.load operates in place, so state_dict_to_load["model"] must still be
        # the same object as model2_sd_before_load.
        self.assertTrue(state_dict_to_load["model"] is model2_sd_before_load)

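        # assign=True assigns the loaded DTensors to the module's parameters
        # instead of copying into the unmaterialized meta-device tensors.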
        model2.load_state_dict(state_dict_to_load["model"], assign=True)
        state_dict_after_load = {"model": model2.state_dict()}

        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_to_load["model"])
        )
        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_after_load["model"])
        )

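        # Every loaded parameter should match the corresponding shard of the original model.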
        for name, param in original_state_dict["model"].items():
            param_to_load = state_dict_to_load["model"][name]
            param_after_load = state_dict_after_load["model"][name]

            # Explicitly check that the parameters are not on the meta device, since
            # assertEqual currently does not handle DTensors on the meta device.
            self.assertTrue(not param_to_load.is_meta)
            self.assertTrue(not param_after_load.is_meta)
            self.assertEqual(param.to_local(), param_to_load.to_local())
            self.assertEqual(param.to_local(), param_after_load.to_local())


if __name__ == "__main__":
    run_tests()