# mypy: allow-untyped-defs
# Copyright (c) Meta Platforms, Inc. and affiliates
"""
The following example demonstrates how to use Pytorch Distributed Checkpoint to save a FSDP model.
This is the current recommended way to checkpoint FSDP.
torch.save() and torch.load() is not recommended when checkpointing sharded models.
"""
import os
import shutil

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint"


def opt_at(opt, idx):
    # Return the idx-th parameter's optimizer state (e.g. Adam's exp_avg buffers).
    return list(opt.state.values())[idx]


def init_model():
    # Build a tiny FSDP-wrapped model and run one forward/backward/step so that
    # the Adam optimizer state (exp_avg, exp_avg_sq) exists before checkpointing.
    model = FSDP(torch.nn.Linear(4, 4).cuda(dist.get_rank()))
    optim = torch.optim.Adam(model.parameters(), lr=0.1)
    model(torch.rand(4, 4)).sum().backward()
    optim.step()

    return model, optim


def print_params(stage, model_1, model_2, optim_1, optim_2):
    with FSDP.summon_full_params(model_1):
        with FSDP.summon_full_params(model_2):
            print(
                f"{stage} --- rank: {dist.get_rank()}\n"
                f"model.weight: {model_1.weight}\n"
                f"model_2.weight:{model_2.weight}\n"
                f"model.bias: {model_1.bias}\n"
                f"model_2.bias: {model_2.bias}\n"
            )

    print(
        f"{stage} --- rank: {dist.get_rank()}\n"
        f"optim exp_avg:{opt_at(optim_1, 0)['exp_avg']}\n"
        f"optim_2 exp_avg:{opt_at(optim_2, 0)['exp_avg']}\n"
        f"optim exp_avg_sq:{opt_at(optim_1, 0)['exp_avg_sq']}\n"
        f"optim_2 exp_avg_sq:{opt_at(optim_2, 0)['exp_avg_sq']}\n"
    )


def run_fsdp_checkpoint_example(rank, world_size):
    # Set up world pg
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    # Initialize the process group
    dist.init_process_group("cpu:gloo,cuda:nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Create a model
    model_1, optim_1 = init_model()

    # Save the model to CHECKPOINT_DIR
    with FSDP.state_dict_type(model_1, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model_1.state_dict(),
            "optim": FSDP.optim_state_dict(model_1, optim_1),
        }

        dist_cp.save_state_dict(
            state_dict=state_dict,
            storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
        )
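
    # Note (version assumption): newer PyTorch releases expose dist_cp.save() as the
    # non-deprecated replacement for dist_cp.save_state_dict(); if your installed
    # version provides it, the call above could instead be written roughly as:
    #   dist_cp.save(state_dict, storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR))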

    # Create a second model
    model_2, optim_2 = init_model()

    # Print the model parameters for both models.
    # Before loading, the parameters should be different.
    print_params("Before loading", model_1, model_2, optim_1, optim_2)

    # Load model_2 with parameters saved in CHECKPOINT_DIR
    with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "model": model_2.state_dict(),
            # cannot load the optimizer state_dict together with the model state_dict
        }

        dist_cp.load_state_dict(
            state_dict=state_dict,
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )
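
        # Note (version assumption): newer PyTorch releases expose dist_cp.load() as
        # the non-deprecated replacement for dist_cp.load_state_dict(); with such a
        # version the call above could be written roughly as:
        #   dist_cp.load(state_dict, storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR))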
        model_2.load_state_dict(state_dict["model"])

        # Loading the optimizer is a two-step process: first read the sharded
        # optimizer values from storage, keyed by the model state_dict, then
        # convert them into the flattened form the FSDP-wrapped optimizer expects
        # before calling load_state_dict on the optimizer itself.
        optim_state = load_sharded_optimizer_state_dict(
            model_state_dict=state_dict["model"],
            optimizer_key="optim",
            storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
        )

        flattened_osd = FSDP.optim_state_dict_to_load(
            model_2, optim_2, optim_state["optim"]
        )
        optim_2.load_state_dict(flattened_osd)

    # Print the model parameters for both models.
    # After loading, the parameters should be the same.
    print_params("After loading", model_1, model_2, optim_1, optim_2)

    # Shut down world pg
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Running fsdp checkpoint example on {world_size} devices.")
    shutil.rmtree(CHECKPOINT_DIR, ignore_errors=True)

    mp.spawn(
        run_fsdp_checkpoint_example,
        args=(world_size,),
        nprocs=world_size,
        join=True,
    )