# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]
# pyre-unsafe
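"""
Example: saving and reloading an FSDP-wrapped model and its optimizer with
torch.distributed.checkpoint (DCP), after patching them to be "stateful" so
dcp.save/dcp.load can call state_dict/load_state_dict on them directly.
"""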
import os
import shutil

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


# Expand "~" explicitly: shutil and dcp treat the path as a literal string.
CHECKPOINT_DIR = os.path.expanduser(f"~/{os.environ['LOGNAME']}/checkpoint")

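# A small feed-forward network; the fixed manual seed makes every rank build
# identical initial weights before FSDP sharding.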
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        torch.manual_seed(0)
        self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
        self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
        self.net3 = nn.Linear(32, 64)
        self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

    def forward(self, x):
        return self.net4(self.net3(self.net2(self.net1(x))))

    def get_input(self):
        return torch.rand(8, 8, device="cuda")

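# Patch state_dict/load_state_dict on the model and optimizer so that
# dcp.save/dcp.load can treat them as Stateful objects and handle the
# FSDP-sharded state correctly.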
def _make_stateful(model, optim):
    _patch_model_state_dict(model)
    _patch_optimizer_state_dict(model, optimizers=optim)

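# Run a few training steps on synthetic input and return the final loss.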
def _train(model, optim, train_steps=1):
    torch.manual_seed(0)
    loss = None
    for _ in range(train_steps):
        loss = model(model.get_input()).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()
    return loss

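# Build the model, shard it with FSDP over a 1-D device mesh, create the
# optimizer, and patch both to be stateful for checkpointing.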
def _init_model(device, world_size):
    device_mesh = init_device_mesh(device, (world_size,))
    model = Model().cuda()
    model = FSDP(
        model,
        device_mesh=device_mesh,
        use_orig_params=True,
    )

    optim = torch.optim.Adam(model.parameters(), lr=0.1)
    _make_stateful(model, optim)

    return model, optim

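# Per-process entry point: train, save a checkpoint, re-initialize from
# scratch, load the checkpoint back, and resume training.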
def run(rank, world_size, device="cuda"):
    # Set up the world process group.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Train for a couple of steps and save a distributed checkpoint.
    model, optim = _init_model(device, world_size)
    _train(model, optim, train_steps=2)
    dcp.save(
        state_dict={"model": model, "optimizer": optim},
        checkpoint_id=CHECKPOINT_DIR,
    )

    # Presumably do something else here.

    # Re-initialize, load the checkpoint, and keep training.
    model, optim = _init_model(device, world_size)
    dcp.load(
        state_dict={"model": model, "optimizer": optim},
        checkpoint_id=CHECKPOINT_DIR,
    )
    _train(model, optim, train_steps=2)

    # Clean up the process group before the worker exits.
    dist.destroy_process_group()

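# Spawn one worker process per visible GPU; clear any stale checkpoint first.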
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running stateful checkpoint example on {world_size} devices.")

    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    mp.spawn(
        run,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )