# Owner(s): ["oncall: distributed"]
from copy import deepcopy

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    MLPModule,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
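
# UnevenShardedModel produces uneven DTensor shards under tensor parallelism:
# net3 has a single output feature, so a colwise shard of its weight cannot be
# split evenly across two or more ranks.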
class UnevenShardedModel(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        torch.manual_seed(5)
        self.net1 = torch.nn.Linear(5, 10, device=device)
        self.relu = torch.nn.ReLU()
        self.net2 = torch.nn.Linear(10, 15, device=device)
        self.net3 = torch.nn.Linear(15, 1, device=device)

    def forward(self, x):
        return self.net3(self.net2(self.relu(self.net1(x))))
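
# These tests save a tensor-parallel model's DTensor state dict with
# torch.distributed.checkpoint (DCP) and verify that loading restores it,
# including loading into a model initialized on the meta device.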
class TestTpCheckpoint(DTensorTestBase):
    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint(self):
        CHECKPOINT_DIR = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # Create the model and move it to the GPU with id equal to this rank.
        model = MLPModule(self.device_type).cuda(self.rank)
        # Parallelize the module according to the given parallel styles.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan)

        optimizer = torch.optim.SGD(model.parameters(), lr=0.25)
        original_state_dict = deepcopy(model.state_dict())
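        # dcp.save is a collective operation: every rank participates and
        # writes its local shards of the sharded DTensors to CHECKPOINT_DIR.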
        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
            planner=DefaultSavePlanner(),
        )
        # Update the parameters so that model.state_dict() differs from
        # original_state_dict.
        torch.manual_seed(0)
        inp = torch.rand(20, 10).cuda(self.rank)
        output = model(inp)
        output.sum().backward()
        optimizer.step()
        state_dict = model.state_dict()

        # Ensure the current model parameters differ from original_state_dict
        # before loading from the checkpoint.
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertNotEqual(param1.to_local(), param2.to_local())
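        # dcp.load is also collective and loads in place: each rank reads back
        # the shards needed to repopulate its local DTensor storage.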
        dcp.load(
            state_dict=state_dict,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
            planner=DefaultLoadPlanner(),
        )

        # After loading from the checkpoint, the current model parameters
        # should once again match original_state_dict.
        for param1, param2 in zip(original_state_dict.values(), state_dict.values()):
            self.assertEqual(param1.to_local(), param2.to_local())
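
    # Save from a CUDA model, then load the checkpoint into a second copy whose
    # parameters were initialized on the meta device (shape-only, no storage),
    # materializing them via load_state_dict(assign=True).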
    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_tp_checkpoint_load_on_meta_device(self):
        CHECKPOINT_DIR = self.temp_dir
        mesh_shape = (self.world_size,)
        tp_mesh = init_device_mesh(self.device_type, mesh_shape)

        # Create the model and move it to the GPU with id equal to this rank.
        model = UnevenShardedModel(self.device_type).cuda(self.rank)
        # Parallelize the module according to the given parallel styles.
        parallelize_plan = {
            "net1": ColwiseParallel(),
            "net2": RowwiseParallel(),
            "net3": ColwiseParallel(),
        }
        model = parallelize_module(model, tp_mesh, parallelize_plan=parallelize_plan)

        original_state_dict = {
            "model": model.state_dict(),
        }
        dcp.save(
            state_dict=original_state_dict,
            storage_writer=dcp.FileSystemWriter(CHECKPOINT_DIR),
        )
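        # model2 lives on the meta device: its parameters carry only shape and
        # dtype metadata with no storage, so real values must come from the
        # checkpoint.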
        model2 = parallelize_module(
            UnevenShardedModel("meta"), tp_mesh, parallelize_plan=parallelize_plan
        )
        model2_sd_before_load = model2.state_dict()
        state_dict_to_load = {"model": model2_sd_before_load}

        dcp.load(
            state_dict=state_dict_to_load,
            storage_reader=dcp.FileSystemReader(CHECKPOINT_DIR),
        )

        # Since dcp.load loads in place, state_dict_to_load["model"] must still
        # be the very same object as model2_sd_before_load.
        self.assertTrue(state_dict_to_load["model"] is model2_sd_before_load)
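        # assign=True makes load_state_dict assign the loaded tensors onto the
        # module rather than copying into the existing meta parameters, which
        # have no storage to copy into.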
        model2.load_state_dict(state_dict_to_load["model"], assign=True)
        state_dict_after_load = {"model": model2.state_dict()}

        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_to_load["model"])
        )
        self.assertEqual(
            len(original_state_dict["model"]), len(state_dict_after_load["model"])
        )
        for name, param in original_state_dict["model"].items():
            param_to_load = state_dict_to_load["model"][name]
            param_after_load = state_dict_after_load["model"][name]
            # Explicitly check that the device is not meta, since assertEqual
            # does not currently handle DTensors on the meta device.
            self.assertFalse(param_to_load.is_meta)
            self.assertFalse(param_after_load.is_meta)
            self.assertEqual(param.to_local(), param_to_load.to_local())
            self.assertEqual(param.to_local(), param_after_load.to_local())

if __name__ == "__main__":
    run_tests()