File: profiler_bench.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (114 lines) | stat: -rw-r--r-- 3,522 bytes | parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
import sys
import timeit

import torch
from torch.utils.benchmark import Timer


# Number of tasks that parallel_workload() forks per call.
PARALLEL_TASKS_NUM: int = 4
# Total torch.mm iterations per workload call. Unset at import time; the
# __main__ section assigns it from --internal-iter before any workload runs.
INTERNAL_ITER: "int | None" = None


def loop_workload(x):
    """Square ``x`` repeatedly with ``torch.mm`` in a single sequential chain.

    Performs ``INTERNAL_ITER`` matrix multiplications, each consuming the
    previous result, and returns the final tensor. ``INTERNAL_ITER`` is a
    module global that must be set before this is called.
    """
    result = x
    for _ in range(INTERNAL_ITER):
        result = torch.mm(result, result)
    return result


def parallel_workload(x):
    """Spread ``INTERNAL_ITER`` ``torch.mm`` calls across forked tasks.

    Forks ``PARALLEL_TASKS_NUM`` tasks with ``torch.jit._fork``. Each task
    starts from the same input ``x`` and repeatedly squares it
    (``torch.mm`` is out-of-place, so ``x`` itself is never modified). All
    tasks are then joined with ``torch.jit._wait``. The tasks' results are
    discarded: this function exists only to generate work, and it returns
    its input ``x``.

    The iteration total across all tasks equals ``INTERNAL_ITER``, matching
    ``loop_workload``. Both ``INTERNAL_ITER`` and ``PARALLEL_TASKS_NUM`` are
    module globals.
    """

    def parallel_task(x, num_iters):
        for _ in range(num_iters):
            x = torch.mm(x, x)
        return x

    # Integer split. The first `remainder` tasks take one extra iteration so
    # no work is dropped when INTERNAL_ITER is not a multiple of the task
    # count. The old `int(a / b)` truncated, and also went through float.
    base, remainder = divmod(INTERNAL_ITER, PARALLEL_TASKS_NUM)
    futs = [
        torch.jit._fork(parallel_task, x, base + (1 if i < remainder else 0))
        for i in range(PARALLEL_TASKS_NUM)
    ]
    for fut in futs:
        torch.jit._wait(fut)
    return x


if __name__ == "__main__":
    torch._C._set_graph_executor_optimize(False)
    parser = argparse.ArgumentParser(description="Profiler benchmark")

    parser.add_argument("--with-cuda", "--with_cuda", action="store_true")
    parser.add_argument("--with-stack", "--with_stack", action="store_true")
    parser.add_argument("--use-script", "--use_script", action="store_true")
    parser.add_argument("--use-kineto", "--use_kineto", action="store_true")
    parser.add_argument(
        "--profiling-tensor-size", "--profiling_tensor_size", default=1, type=int
    )
    parser.add_argument("--workload", "--workload", default="loop", type=str)
    parser.add_argument("--internal-iter", "--internal_iter", default=256, type=int)
    parser.add_argument(
        "--timer-min-run-time", "--timer_min_run_time", default=10, type=int
    )
    parser.add_argument("--cuda-only", "--cuda_only", action="store_true")

    args = parser.parse_args()

    if args.with_cuda and not torch.cuda.is_available():
        print("No CUDA available")
        sys.exit()

    print(
        f"Payload: {args.workload}, {args.internal_iter} iterations; timer min. runtime = {args.timer_min_run_time}\n"
    )
    INTERNAL_ITER = args.internal_iter

    for profiling_enabled in [False, True]:
        print(
            "Profiling {}, tensor size {}x{}, use cuda: {}, use kineto: {}, with stacks: {}, use script: {}".format(
                "enabled" if profiling_enabled else "disabled",
                args.profiling_tensor_size,
                args.profiling_tensor_size,
                args.with_cuda,
                args.use_kineto,
                args.with_stack,
                args.use_script,
            )
        )

        input_x = torch.rand(args.profiling_tensor_size, args.profiling_tensor_size)

        if args.with_cuda:
            input_x = input_x.cuda()

        workload = None
        assert args.workload in ["loop", "parallel"]
        if args.workload == "loop":
            workload = loop_workload
        else:
            workload = parallel_workload

        if args.use_script:
            traced_workload = torch.jit.trace(workload, (input_x,))
            workload = traced_workload

        if profiling_enabled:

            def payload():
                x = None
                with torch.autograd.profiler.profile(
                    use_cuda=args.with_cuda,
                    with_stack=args.with_stack,
                    use_kineto=args.use_kineto,
                    use_cpu=not args.cuda_only,
                ):
                    x = workload(input_x)
                return x

        else:

            def payload():
                return workload(input_x)

        t = Timer(
            "payload()",
            globals={"payload": payload},
            timer=timeit.default_timer,
        ).blocked_autorange(min_run_time=args.timer_min_run_time)
        print(t)