# Owner(s): ["oncall: distributed"]

import os
import sys

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


DIM = 500
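

# A simple four-layer MLP standing in for a pipeline-parallel model: in the
# test below, rank i materializes layer{i + 1} on GPU before saving.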
class PipelineModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = nn.Linear(DIM, DIM)
        self.layer2 = nn.Linear(DIM, DIM)
        self.layer3 = nn.Linear(DIM, DIM)
        self.layer4 = nn.Linear(DIM, DIM)
        self.relu = nn.ReLU()

    def forward(self, batch):
        x = self.relu(self.layer1(batch))
        x = self.relu(self.layer2(x))
        x = self.relu(self.layer3(x))
        x = self.relu(self.layer4(x))
        return x


class TestPipeline(FSDPTest):
    @property
    def world_size(self) -> int:
        return min(4, torch.cuda.device_count())
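
    # Simulates the "save" side of pipeline parallelism: each rank keeps the
    # model on the meta device except for its own stage, materializes that
    # stage on GPU, and writes model/optimizer state through DCP.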
    def save_with_pipeline(self, pipeline_dir: str) -> None:
        with torch.device("meta"):
            model = PipelineModel()
        pipeline_modules = [model.layer1, model.layer2, model.layer3, model.layer4]

        # Materialize only this rank's pipeline stage on GPU.
        submodule = pipeline_modules[self.rank]
        submodule.to_empty(device=torch.device("cuda"))
        # submodule.reset_parameters()
        optim = torch.optim.Adam(submodule.parameters(), lr=1e-3)

        # Skip training; there is no real pipeline parallelism in this test.

        # Save the state_dict.
        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
        saved_state_dict = {"model": model_state_dict, "optim": optim_state_dict}
        dcp.save(
            state_dict=saved_state_dict,
            storage_writer=dcp.FileSystemWriter(pipeline_dir),
        )
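
    # Simulates the "load" side: a fresh copy of the model is wrapped with
    # FSDP and the pipeline-style checkpoint is read back through DCP, which
    # reshards the saved tensors into FSDP's sharded state_dict layout.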
    def load_with_fsdp(self, pipeline_dir: str) -> None:
        model = FSDP(PipelineModel().cuda())
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)

        # Load the checkpoint
        model_state_dict, optim_state_dict = get_state_dict(model, optimizers=optim)
        dcp.load(
            {"model": model_state_dict, "optim": optim_state_dict},
            storage_reader=dcp.FileSystemReader(pipeline_dir),
        )
        set_state_dict(
            model,
            optimizers=optim,
            model_state_dict=model_state_dict,
            optim_state_dict=optim_state_dict,
        )
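
    # End-to-end test: rank 0 creates the checkpoint directory, every rank
    # saves as a pipeline stage, and then every rank reloads the checkpoint
    # under FSDP.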
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_pipeline(self) -> None:
        self.assertTrue(os.path.exists(self.temp_dir))
        pipeline_dir = os.path.join(self.temp_dir, "pipeline")
        if self.rank == 0:
            os.mkdir(pipeline_dir)
        os.sync()
        dist.barrier()
        self.assertTrue(os.path.exists(pipeline_dir))

        self.save_with_pipeline(pipeline_dir)
        self.load_with_fsdp(pipeline_dir)


if __name__ == "__main__":
    run_tests()