File: distributed.py

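"""Benchmark harness for DDP/FSDP training with TorchDynamo backends.

RANK and WORLD_SIZE are read from the environment, so use any launcher
that sets them. An assumed typical invocation (torchrun sets both):

    torchrun --nproc_per_node=2 distributed.py --toy_model --ddp --dynamo inductor
"""
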
import argparse
import logging
import os
from functools import partial

import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function


try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup

log = logging.getLogger(__name__)


def torchviz_model(args, model, inputs, rank):
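    """Render the autograd graph of one forward pass; rank 0 writes 'torchviz.dot'."""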
    from torchviz import make_dot

    outputs = model(*inputs)
    loss = reduce_to_scalar_loss(outputs)
    parameter_names = dict(model.named_parameters())
    dot = make_dot(loss, params=parameter_names, show_attrs=True, show_saved=True)
    if rank == 0:
        dot.render("torchviz.dot")


def profile_model(args, model, inputs, rank):
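    """Profile args.repeat forward/backward iterations; rank 0 exports a chrome trace."""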
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        for i in range(args.repeat):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)
            with record_function("Backward"):
                loss.backward()
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)


def run_model(args, model, inputs, key):
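    """Set up the process group, optionally wrap (FSDP/DDP) and compile the model,
    then time args.repeat iterations of model_iter_fn.

    key names the run (model + world size) and is currently unused.
    """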
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))

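    # Initialize the process group for this rank (setup is provided by dist_util).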
    setup(rank, world_size)
    if args.device == "cuda":
        # needed for FSDP
        torch.cuda.set_device(rank)

    dev_rank = f"{args.device}:{rank}"
    model = model.to(dev_rank)

    def move_tensor(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.to(dev_rank)
        return maybe_tensor

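    # Move every tensor leaf of the (possibly nested) inputs to this rank's device.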
    inputs = pytree.tree_map(move_tensor, inputs)

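    # Wrap for distributed training: FSDP shards parameters, DDP replicates them.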
    if args.fsdp:
        model = apply_fsdp(
            args,
            model,
            use_checkpointing=args.fsdp_checkpoint,
            use_wrap_policy=args.fsdp_wrap,
        )
    elif args.ddp:
        model = DDP(model)

    if args.verbose:
        print(model)

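    # Compile with TorchDynamo when a backend was requested.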
    if args.dynamo:
        dynamo.reset()
        if args.verbose:
            dynamo.config.verbose = True
            dynamo.config.log_level = logging.DEBUG
        if args.dynamo_no_optimize_ddp:
            dynamo.config.optimize_ddp = False
        if args.dynamo == "inductor" and args.fsdp:
            torch._inductor.config.triton.cudagraphs = False
            log.warning("disabling inductor cudagraphs for compatibility with FSDP")

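        # Debug backend: print each captured FX graph and return it to run eagerly.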
        def print_compile(gm, ex):
            print(
                f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
            )
            return gm

        dynamo_ctx = dynamo.optimize(
            print_compile if args.dynamo == "print" else args.dynamo
        )
        model = dynamo_ctx(model)

    # Warm-up iterations so compilation and other one-time costs stay out of the timed run.
    _ = timed(model, model_iter_fn, inputs, times=3, return_result=False)
    t_total = timed(
        model, model_iter_fn, inputs, times=args.repeat, return_result=False
    )
    if args.torchviz:
        torchviz_model(args, model, inputs, rank)
    if args.profile:
        profile_model(args, model, inputs, rank)

    cleanup()
    return t_total


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="TorchDynamo backend to compile with (e.g. 'inductor', or 'print' to dump graphs); eager if unset",
    )
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--batch-size", "--batch_size", default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="Output file for the chrome trace written by --profile",
    )
    parser.add_argument("--repeat", default=10, help="Repeats for timing run")
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )

    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()

    model_name = args.torchbench_model
    if args.toy_model:
        model_name = "ToyModel"
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    world_size = int(os.getenv("WORLD_SIZE", "1"))
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")