# Owner(s): ["oncall: distributed"]
import os
import sys
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.nn as nn
from torch.distributed._tensor import init_device_mesh
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_state_dict,
set_model_state_dict,
set_state_dict,
StateDictOptions,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
file=sys.stderr,
)
    sys.exit(0)

DIM = 500
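

# The "pretrained" backbone: a stack of linear layers (plus Sequential and
# ModuleList variants) whose weights are produced during the pretrain phase and
# frozen during fine-tuning.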
class PreTrainedModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.layer1 = nn.Linear(DIM, DIM)
self.layer2 = nn.Linear(DIM, DIM)
self.layer3 = nn.Linear(DIM, DIM)
self.sequential = nn.Sequential(nn.Linear(DIM, DIM), nn.ReLU())
self.module_list = nn.ModuleList([nn.Linear(DIM, DIM), nn.ReLU()])
        self.relu = nn.ReLU()

def forward(self, batch):
x = self.relu(self.layer1(batch))
x = self.relu(self.layer2(x))
x = self.relu(self.layer3(x))
x = self.sequential(x)
x = self.module_list[1](self.module_list[0](x))
return x
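

# Fine-tuning model: wraps a frozen PreTrainedModel and adds new trainable
# layers on top of it.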
class FineTuningModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.pretrain = PreTrainedModel()
for p in self.pretrain.parameters():
p.requires_grad = False
self.layer1 = nn.Linear(DIM, DIM)
self.layer2 = nn.Linear(DIM, DIM)
self.layer3 = nn.Linear(DIM, DIM)
        self.relu = nn.ReLU()

def forward(self, batch):
x = self.relu(self.pretrain(batch))
x = self.relu(self.layer1(x))
x = self.relu(self.layer2(x))
x = self.relu(self.layer3(x))
return x
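

# End-to-end fine-tuning flow: pretrain and checkpoint the backbone with
# torch.distributed.checkpoint (DCP), then fine-tune a partially frozen model,
# checkpointing only the trainable parameters and their optimizer state.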
class TestFineTuning(DTensorTestBase):
@property
def world_size(self) -> int:
        return min(4, torch.cuda.device_count())

@property
def backend(self):
return "cpu:gloo,cuda:nccl"
def pretrain(self, pretrain_dir: str) -> None:
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model = PreTrainedModel().cuda()
model = FSDP(model, device_mesh=device_mesh)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)
# Training
for i in range(3):
batch = torch.rand(32, DIM, device="cuda")
loss = model(batch).sum()
loss.backward()
optim.step()
optim.zero_grad()
# Save state_dict
model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
dist_cp.save(
state_dict=saved_state_dict,
storage_writer=dist_cp.FileSystemWriter(pretrain_dir),
)
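
    # Phase 2: build the fine-tuning model, restore the frozen backbone from the
    # pretrain checkpoint, and save/restore only the trainable state from/to
    # finetune_dir.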
def finetune(self, pretrain_dir: str, finetune_dir: str) -> None:
device_mesh = init_device_mesh(self.device_type, (self.world_size,))
model = FineTuningModel().cuda()
# TODO: make the parallelism more complicated, e.g., using 2D + DDP.
model = FSDP(model, use_orig_params=True, device_mesh=device_mesh)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Simulate a fine-tuning restart after 3 iterations
for i in range(2):
# Load pretrain submodules checkpoint
pretrain_state_dict = get_model_state_dict(
model,
submodules={model.pretrain},
options=StateDictOptions(keep_submodule_prefixes=False),
)
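            # keep_submodule_prefixes=False strips the leading "pretrain." from the
            # keys, so they match the keys saved from PreTrainedModel in pretrain().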
dist_cp.load(
{"model": pretrain_state_dict},
storage_reader=dist_cp.FileSystemReader(pretrain_dir),
)
set_model_state_dict(
model,
model_state_dict={model.pretrain: pretrain_state_dict},
options=StateDictOptions(strict=False),
)
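            # Restore the trainable (non-frozen) parameters and optimizer state from
            # the fine-tuning checkpoint, if one exists from a previous run.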
try:
# Load training submodules checkpoint
model_state_dict, optim_state_dict = get_state_dict(
model,
optimizers=optim,
options=StateDictOptions(ignore_frozen_params=True),
)
                dist_cp.load(
                    {"model": model_state_dict, "optim": optim_state_dict},
                    storage_reader=dist_cp.FileSystemReader(finetune_dir),
                )
set_state_dict(
model,
optimizers=optim,
model_state_dict=model_state_dict,
optim_state_dict=optim_state_dict,
options=StateDictOptions(strict=False),
)
except KeyError:
                # If this is the first round of fine-tuning, nothing has been saved
                # yet, so the load raises. If this is a restart of fine-tuning, the
                # checkpoint should exist.
self.assertEqual(i, 0)
# Training
for j in range(3):
batch = torch.rand(32, DIM, device="cuda")
loss = model(batch).sum()
loss.backward()
optim.step()
optim.zero_grad()
# Save state_dict
model_state_dict, optim_state_dict = get_state_dict(
model,
optimizers=optim,
options=StateDictOptions(ignore_frozen_params=True),
)
saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
dist_cp.save(
state_dict=saved_state_dict,
storage_writer=dist_cp.FileSystemWriter(finetune_dir),
)
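
    # Requires 4 GPUs. Rank 0 creates the checkpoint directories, then every rank
    # runs the pretrain phase followed by the fine-tuning phase.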
@skip_if_lt_x_gpu(4)
@with_comms
@with_temp_dir
def test_fine_tuning(self) -> None:
self.assertTrue(os.path.exists(self.temp_dir))
pretrain_dir = os.path.join(self.temp_dir, "pretrain")
finetune_dir = os.path.join(self.temp_dir, "finetune")
print(pretrain_dir, finetune_dir)
if self.rank == 0:
os.mkdir(pretrain_dir)
os.mkdir(finetune_dir)
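        # Ensure the directories created by rank 0 are visible to all ranks before
        # any rank starts writing checkpoints.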
dist.barrier()
os.sync()
self.assertTrue(os.path.exists(pretrain_dir))
self.assertTrue(os.path.exists(finetune_dir))
self.pretrain(pretrain_dir)
        self.finetune(pretrain_dir, finetune_dir)


if __name__ == "__main__":
run_tests()