# mypy: allow-untyped-defs
# Owner(s): ["oncall: distributed"]
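"""
Example: asynchronous distributed checkpointing (DCP) with FSDP.

Each rank trains a small model, takes an asynchronous checkpoint via
dcp.state_dict_saver.async_save every SAVE_PERIOD epochs, injects a fault every
FAULT_PERIOD epochs, and recovers by reloading the last completed checkpoint.
"""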
import os
import shutil
import traceback

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.checkpoint.state_dict import (
    _patch_model_state_dict,
    _patch_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor.device_mesh import init_device_mesh
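

# Checkpoint/fault-injection knobs: take an async checkpoint every SAVE_PERIOD
# epochs and raise an InjectedException every FAULT_PERIOD epochs.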
DEVICE = "cuda"
NUM_EPOCHS = 1000
SAVE_PERIOD = 10
FAULT_PERIOD = 25
CHECKPOINT_DIR = os.path.expanduser(f"~/{os.environ.get('LOGNAME', '')}/checkpoint")


class InjectedException(Exception):
    pass
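

# Toy five-layer MLP standing in for a real training workload; it learns to
# predict whether the features of each input row sum to at least 4.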
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(8, 32)
        self.net2 = nn.Linear(32, 128)
        self.net3 = nn.Linear(128, 64)
        self.net4 = nn.Linear(64, 8)
        self.net5 = nn.Linear(8, 1)

    def forward(self, x):
        x = F.relu(self.net1(x))
        x = F.relu(self.net2(x))
        x = F.relu(self.net3(x))
        x = F.relu(self.net4(x))
        x = F.sigmoid(self.net5(x))
        return x
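

# Build the model, shard it with FSDP over a 1-D device mesh, and create the
# optimizer. _patch_model_state_dict / _patch_optimizer_state_dict are private
# DCP helpers that patch state_dict()/load_state_dict() on the model and
# optimizer so both can be passed directly to dcp.async_save / dcp.load.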
def _init_model(rank, world_size):
    # Create a dummy model and wrap it in FSDP
    device_mesh = init_device_mesh(DEVICE, (world_size,))
    model = Model().cuda()
    model = FSDP(model, device_mesh=device_mesh, use_orig_params=True)
    optim = torch.optim.Adam(model.parameters(), lr=0.0001)

    _patch_model_state_dict(model)
    _patch_optimizer_state_dict(model, optimizers=optim)

    return model, optim


def _print(msg):
    if dist.get_rank() == 0:
        print(msg)
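

# Generate a synthetic batch: each sample is labeled 1.0 when its features sum
# to at least 4, otherwise 0.0.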
def _input():
    x = torch.rand(128, 8, device="cuda")
    y = torch.zeros(128, 1, device="cuda")

    y[torch.sum(x, dim=1) >= 4] = 1.0

    return x, y
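

# Per-rank training loop: train, checkpoint asynchronously every SAVE_PERIOD
# epochs, inject a fault every FAULT_PERIOD epochs, and recover from the last
# completed checkpoint when the fault is caught.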
def run(rank, world_size):
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model, optim = _init_model(rank, world_size)
    state_dict = {"model": model, "optim": optim}
    loss_calc = torch.nn.BCELoss()
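
    # "f" holds the future returned by the most recent async_save call; it is
    # None until the first checkpoint has been started.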
    f = None
    for epoch in range(NUM_EPOCHS):
        try:
            torch.manual_seed(epoch)
            x, y = _input()

            loss = loss_calc(model(x), y)
            _print(f"{epoch=} {loss=}")

            loss.backward()
            optim.step()
            optim.zero_grad()
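
            # Every SAVE_PERIOD epochs take an asynchronous checkpoint: block on
            # the previous save if it is still in flight, then start a new one.
            # The new save proceeds in the background while training continues.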
            if epoch % SAVE_PERIOD == 0:
                if f is not None:
                    f.result()
                f = dcp.state_dict_saver.async_save(
                    state_dict, checkpoint_id=CHECKPOINT_DIR
                )
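
            # Simulate a hard failure every FAULT_PERIOD epochs to exercise the
            # recovery path below.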
            if FAULT_PERIOD > 0 and epoch % FAULT_PERIOD == 0:
                raise InjectedException("Fault injection!")

        except InjectedException as e:
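            # Recovery: synchronize all ranks, wait for any in-flight async save
            # to finish, then restore model and optimizer state from the last
            # completed checkpoint.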
            dist.barrier()

            _print("Trainer encountered exception:")
            traceback.print_tb(e.__traceback__)

            _print("Reloading model from last checkpoint!")

            if f is not None:
                f.result()
            dcp.load(state_dict, checkpoint_id=CHECKPOINT_DIR)


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running an example of Async Checkpointing on {world_size} devices.")
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    mp.spawn(
        run,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )