import argparse
import functools
import importlib
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    checkpoint_wrapper,
    CheckpointImpl,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

try:
    from .torchbench import setup_torchbench_cwd
except ImportError:
    from torchbench import setup_torchbench_cwd

from transformers.models.bert.modeling_bert import BertLayer, BertLMPredictionHead
from transformers.models.t5.modeling_t5 import T5Block


def setup(rank, world_size):
    # Default to a single-process "cluster" when the usual launcher env vars
    # (e.g. from torchrun) are not set, so the script also runs standalone.
    os.environ["MASTER_ADDR"] = os.getenv("MASTER_ADDR", "localhost")
    os.environ["MASTER_PORT"] = os.getenv("MASTER_PORT", "12355")
    os.environ["RANK"] = os.getenv("RANK", "0")
    os.environ["WORLD_SIZE"] = os.getenv("WORLD_SIZE", "1")
    dist.init_process_group("nccl")


def cleanup():
    dist.destroy_process_group()
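
# Illustrative usage sketch (commented out, not executed): setup()/cleanup()
# are meant to bracket a distributed run, typically launched with ``torchrun``.
# Note that setup() currently ignores its rank/world_size arguments and relies
# on the environment variables above; the values here are assumptions.
#
#   # torchrun --nproc_per_node=2 this_script.py
#   setup(rank=0, world_size=1)
#   try:
#       ...  # build the model, wrap it, run training iterations
#   finally:
#       cleanup()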


class CustomLinear(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(a, b))

    def forward(self, x):
        return torch.mm(x, self.weight)


class MyModule(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(a, b),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Stack of linear blocks; the MyModule instances are the units that
        # FSDP wraps and checkpoints (see MODEL_FSDP_WRAP below).
        self.net = nn.Sequential(
            *[nn.Linear(10, 10000), nn.ReLU()]
            + [nn.Linear(10000, 10000), nn.ReLU()]
            + [MyModule(10000, 10000)]
            + [MyModule(10000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [MyModule(1000, 1000)]
            + [nn.Linear(1000, 5)]
        )

    def forward(self, x):
        return self.net(x)


def model_iter_fn(model, example_inputs, collect_outputs=False):
    # One training iteration: forward pass, reduce the (possibly structured)
    # outputs to a scalar loss, then backward pass.
    outputs = model(*example_inputs)
    loss = reduce_to_scalar_loss(outputs)
    loss.backward()
    if collect_outputs:
        return outputs
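
# Illustrative call (commented out, not executed): one forward/backward step on
# the toy model. Device placement is an assumption here, not something this
# helper handles itself.
#
#   model = ToyModel().cuda()
#   example_inputs = (torch.randn(20, 10, device="cuda"),)
#   model_iter_fn(model, example_inputs)            # gradients now populated
#   outputs = model_iter_fn(model, example_inputs, collect_outputs=True)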


def get_model(args):
    if args.torchbench_model:
        setup_torchbench_cwd()
        module = importlib.import_module(
            f"torchbenchmark.models.{args.torchbench_model}"
        )
        benchmark_cls = getattr(module, "Model", None)
        bm = benchmark_cls(
            test="train", device=args.device, batch_size=args.batch_size
        )
        model, inputs = bm.get_module()
    elif args.toy_model:
        model = ToyModel()
        inputs = (torch.randn(20, 10),)
    else:
        raise argparse.ArgumentError(
            args.torchbench_model, message="Must specify a model"
        )
    return model, inputs


def fsdp_checkpointing_base(model, blocks):
    """Apply activation checkpointing to ``model`` in place.

    Returns None; the model is updated directly.
    """
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        offload_to_cpu=False,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )

    def check_fn(submodule):
        return isinstance(submodule, blocks)

    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
    )


# Maps model name -> module classes used as FSDP auto-wrap / checkpoint units.
MODEL_FSDP_WRAP = {
    "toy_model": (MyModule,),
    "hf_Bert": (BertLayer, BertLMPredictionHead),
    "hf_T5": (T5Block,),
}
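
# To cover another torchbench model, an entry could be added mapping its name
# to the block classes FSDP should treat as units. The mapping below is a
# hypothetical example, not part of the original table:
#
#   from transformers.models.gpt2.modeling_gpt2 import GPT2Block
#   MODEL_FSDP_WRAP["hf_GPT2"] = (GPT2Block,)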


def apply_fsdp(args, model, use_checkpointing=False, use_wrap_policy=True):
    wrap_policy = None
    blocks = MODEL_FSDP_WRAP[
        "toy_model" if model.__class__ is ToyModel else args.torchbench_model
    ]
    if use_wrap_policy:
        wrap_policy = ModuleWrapPolicy(blocks)

    model = FSDP(model, auto_wrap_policy=wrap_policy, use_orig_params=True)
    if use_checkpointing:
        fsdp_checkpointing_base(model, blocks)
    return model
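

# End-to-end sketch (commented out, not executed). The argparse.Namespace
# fields mirror what get_model/apply_fsdp read; the values are assumptions
# about the caller's CLI, not defaults defined in this module.
#
#   args = argparse.Namespace(
#       torchbench_model=None, toy_model=True, device="cuda", batch_size=None
#   )
#   setup(rank=0, world_size=1)
#   model, inputs = get_model(args)
#   model = apply_fsdp(args, model.cuda(), use_checkpointing=True)
#   model_iter_fn(model, tuple(t.cuda() for t in inputs))
#   cleanup()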