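"""Distributed benchmark harness: runs a torchbench (or toy) model under DDP or
FSDP, optionally compiled with torch._dynamo, and reports mean per-iteration
latency. One process per rank; RANK and WORLD_SIZE are read from the
environment, e.g. as set by torchrun.

Example launches (a sketch, assuming two local GPUs and that this file is
saved as distributed.py):

    torchrun --nproc_per_node 2 distributed.py --toy_model --ddp
    torchrun --nproc_per_node 2 distributed.py --torchbench_model hf_Bert --fsdp --dynamo inductor
"""
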
import argparse
import logging
import os
from functools import partial
import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function
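
# The relative imports resolve when this file runs as part of a package; the
# plain-import fallback covers executing it directly as a script.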
try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup

log = logging.getLogger(__name__)


def torchviz_model(args, model, inputs, rank):
    # Render the autograd graph for one forward/loss pass; only rank 0 writes
    # the file so concurrent ranks don't clobber each other.
    from torchviz import make_dot

    outputs = model(*inputs)
    loss = reduce_to_scalar_loss(outputs)
    parameter_names = dict(model.named_parameters())
    dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True)
    if rank == 0:
        dot.render("torchviz.dot")


def profile_model(args, model, inputs, rank):
    # Capture args.repeat forward/backward iterations with the torch profiler.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for _ in range(args.repeat):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)
            with record_function("Backward"):
                loss.backward()
    # Export after the profiler context exits; only rank 0 writes the trace.
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)
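
# Possible extension (hypothetical, not wired to a flag here): print a console
# summary from the same profile object inside profile_model, e.g.
#     print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))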


def run_model(args, model, inputs, key):
    # `key` is currently unused; __main__ builds it as "<model>_<world_size>".
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

    inputs = pytree.tree_map(move_tensor, inputs)
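    # pytree.tree_map walks arbitrarily nested containers, so a (hypothetical)
    # inputs like {"input_ids": t1, "masks": [t2, t3]} comes back with the same
    # structure, tensor leaves moved to dev_rank, and other leaves untouched.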

    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

        def print_compile(gm, ex):
            # "print" backend: dump each captured FX graph, then run it unchanged.
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)
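
    # dynamo.optimize returns a wrapper; applying it to the module routes every
    # forward call through graph capture, compiling lazily on first use.
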
    # warmup (also absorbs the one-time compile cost when dynamo is enabled)
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="torch._dynamo backend to use (e.g. inductor); the special value "
        "'print' dumps each captured graph; default is eager",
    )
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--batch-size", "--batch_size", type=int, default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="Output path for the chrome trace written by --profile",
    )
    parser.add_argument("--repeat", type=int, default=10, help="Repeats for timing run")
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )

    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()

    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)
    world_size = int(os.getenv("WORLD_SIZE", 1))
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")