File: cse.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (124 lines) | stat: -rw-r--r-- 2,432 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torch.fx as fx
from functorch import make_fx
from torch._functorch.compile_utils import fx_graph_cse
from torch.profiler import profile, ProfilerActivity


def profile_it(f, inp, warmup=5, itr=5):
    """Return the average profiled CUDA time of one ``f(inp)`` call.

    ``f`` is first called ``warmup`` times outside the profiler, then ``itr``
    times inside a ``torch.profiler.profile`` context that records CUDA
    activity only. The ``cuda_time_total`` of every averaged profiler event
    is summed and divided by ``itr``.

    Args:
        f: Callable invoked as ``f(inp)``.
        inp: The single argument handed to ``f`` on every call.
        warmup: Number of un-profiled calls made before measuring.
        itr: Number of profiled calls the total time is averaged over.

    Returns:
        The summed ``cuda_time_total`` of all averaged events divided by
        ``itr``.

    Raises:
        ValueError: If ``itr`` is not a positive number.
    """
    # Reject a non-positive iteration count up front: it would otherwise end
    # in a division by zero (or a meaningless negative average) after the
    # warm-up work has already been done.
    if itr <= 0:
        raise ValueError(f"itr must be positive, got {itr}")

    # Warm-up calls are kept outside the profiler so they are not measured.
    for _ in range(warmup):
        f(inp)

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(itr):
            f(inp)

    # NOTE(review): recent PyTorch releases appear to expose this value as
    # `device_time_total`; confirm `cuda_time_total` is still available on the
    # PyTorch version this script targets.
    return sum(e.cuda_time_total for e in prof.key_averages()) / itr


def profile_function(name, f, inp):
    """Trace ``f``, run CSE on its graph, and print a comparison line.

    ``f`` is traced with ``make_fx`` on ``inp``; ``fx_graph_cse`` is applied to
    the traced graph and the result is wrapped in a new ``fx.GraphModule``.
    Both modules are timed with ``profile_it`` and one comma-separated line is
    printed: name, time of the traced module, time of the CSE'd module, number
    of nodes removed, and number of nodes in the traced graph.

    Args:
        name: Label printed as the first field of the output line.
        f: Function to trace and benchmark.
        inp: Input used both for tracing and for timing.
    """
    traced = make_fx(f)(inp)
    deduped = fx.GraphModule(traced, fx_graph_cse(traced.graph))

    # The FX modules are timed directly rather than their torch.jit.script
    # versions, because scripting already does some CSE of its own.
    avg_cuda_time_f = profile_it(traced, inp)
    avg_cuda_time_g = profile_it(deduped, inp)

    nodes_before = len(traced.graph.nodes)
    num_node_decrease = nodes_before - len(deduped.graph.nodes)

    print(
        f"{name}, {avg_cuda_time_f}, {avg_cuda_time_g}, {num_node_decrease}, {nodes_before}"
    )


# CUDA random generator with a fixed seed, so the input below is reproducible.
g_gpu = torch.Generator(device="cuda").manual_seed(2147483647)
# Shared benchmark input: 2**20 standard-normal samples created on the GPU.
inp = torch.randn(1 << 20, device="cuda", generator=g_gpu)


def f1(x):
    """Return the cosine of the cosine of ``x``."""
    return torch.cos(torch.cos(x))


# Benchmark f1 on the shared input and print its result line.
profile_function(name="f1", f=f1, inp=inp)


def fsum(x):
    """Compute ``x.sum()`` four separate times and add the four results."""
    # The four reductions are deliberately computed independently before any
    # addition happens, then added left to right.
    first, second, third, fourth = (torch.sum(x) for _ in range(4))
    return first + second + third + fourth


# Benchmark fsum on the shared input and print its result line.
profile_function(name="fsum", f=fsum, inp=inp)


def fconcat(x):
    """Concatenate ``x`` with itself twice and add the two concatenations."""
    left, right = (torch.cat((x, x)) for _ in range(2))
    return left + right


# Benchmark fconcat on the shared input and print its result line.
profile_function(name="fconcat", f=fconcat, inp=inp)


def fsum2(x):
    """Start from ``x.sum()`` and add a freshly computed ``x.sum()`` 30 times."""
    total = torch.sum(x)
    for _ in range(30):
        # Recompute the reduction every iteration instead of reusing it.
        extra = torch.sum(x)
        total = total + extra
    return total


# Benchmark fsum2 on the shared input and print its result line.
profile_function(name="fsum2", f=fsum2, inp=inp)


def fsummulti(x):
    """Three times over, add ``x.sum()`` to an accumulator then multiply by ``x.sum()``.

    The accumulator starts as the Python integer ``0``; each of the two
    reductions per iteration is computed separately.
    """
    acc = 0
    for _ in range(3):
        # Left operand first: sum, add, then a second sum, then multiply.
        acc = (acc + torch.sum(x)) * torch.sum(x)
    return acc


# Benchmark fsummulti on the shared input and print its result line.
profile_function(name="fsummulti", f=fsummulti, inp=inp)


def fsummulti2(x):
    """Thirty times over, add ``x.sum()`` to an accumulator then multiply by ``x.sum()``.

    The accumulator starts as the Python integer ``0``; each of the two
    reductions per iteration is computed separately.
    """
    acc = 0
    for _ in range(30):
        added = acc + torch.sum(x)
        acc = added * torch.sum(x)
    return acc


# Benchmark fsummulti2 on the shared input and print its result line.
profile_function(name="fsummulti2", f=fsummulti2, inp=inp)


def fcos(x):
    """Add a freshly computed ``x.cos()`` to an accumulator three times.

    The accumulator starts as the Python integer ``0``.
    """
    running = 0
    for _ in range(3):
        running = running + torch.cos(x)
    return running


# Benchmark fcos on the shared input and print its result line.
profile_function(name="fcos", f=fcos, inp=inp)


def fcos2(x):
    """Add a freshly computed ``x.cos()`` to an accumulator thirty times.

    The accumulator starts as the Python integer ``0``.
    """
    running = 0
    for _ in range(30):
        cosine = torch.cos(x)
        running = running + cosine
    return running


# Benchmark fcos2 on the shared input and print its result line.
profile_function(name="fcos2", f=fcos2, inp=inp)