File: bench_mm_fusion.py

Package: pytorch-cuda 2.6.0+dfsg-7

# flake8: noqa
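"""Microbenchmark comparing TorchInductor's ATen and Triton matmul lowerings,
with and without fused bias/ReLU epilogues, on AlexNet, BERT, and GPT-2 layer
shapes. Results are printed as a PrettyTable of per-layer throughput."""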

import triton
from prettytable import PrettyTable

import torch
import torch._dynamo
import torch._inductor.config
from torch._inductor.runtime.benchmarking import benchmarker


# torch._inductor.config.debug = True
torch._inductor.config.triton.dense_indexing = True
torch.manual_seed(0)


# The flag below controls whether to allow TF32 on matmul.
torch.backends.cuda.matmul.allow_tf32 = True


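# Each wrapper below is compiled with TorchInductor so that the pointwise epilogue
# (bias add and/or ReLU) can be fused into the generated matmul kernel.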
class Func:
    # mm
    @torch._dynamo.optimize("inductor")
    def mm(a, b, bias):
        y = torch.mm(a, b)
        return y

    # mm+bias
    @torch._dynamo.optimize("inductor")
    def mm_add(a, b, bias):
        y = torch.mm(a, b)
        return y + bias

    # relu(mm)
    @torch._dynamo.optimize("inductor")
    def mm_relu(a, b, bias):
        y = torch.mm(a, b)
        return torch.relu(y)

    # relu(mm+bias)
    @torch._dynamo.optimize("inductor")
    def mm_add_relu(a, b, bias):
        y = torch.mm(a, b)
        y += bias
        return torch.relu(y)


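# Benchmark one layer shape: for each fusion variant, time the compiled function
# with the ATen mm lowering and with the Triton mm template, and append a row of
# throughput numbers to table `p`.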
def bench(shape, layer_id, p, fusion_types=("",)):
    dtype = torch.float16
    M, K = shape[0]
    _, N = shape[1]
    torch.manual_seed(0)
    # allocate inputs
    a = torch.randn(shape[0], device="cuda", dtype=dtype)
    b = torch.randn(shape[1], device="cuda", dtype=dtype)

    def tflops(ms):
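        # ms is in milliseconds; this counts M*K*N multiply-accumulates,
        # i.e. half of the conventional 2*M*K*N FLOP count for a matmul.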
        return M * K * N / ms * 1e-9

    row = [layer_id]
    for fusion_type in fusion_types:
        if fusion_type == "":
            fn_mm = Func.mm
        else:
            fn_mm = getattr(Func, f"mm_{fusion_type}")

        if "add" in fusion_type:
            bias = torch.randn((M, N), dtype=dtype, device="cuda")
        else:
            bias = None

        args = (a, b, bias)

        def fn():
            return fn_mm(*args)

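        # Time the same compiled function twice: first with Inductor lowering
        # torch.mm to the ATen kernel, then with the Triton matmul template.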
        torch._inductor.config.triton.mm = "aten"
        torch_mm_ms = benchmarker.benchmark_gpu(fn)
        torch._inductor.config.triton.mm = "triton"
        # reset so Dynamo recompiles and Inductor regenerates the kernel with the Triton backend
        torch._dynamo.reset()
        torch._inductor.metrics.reset()
        triton_mm_ms = benchmarker.benchmark_gpu(fn)
        assert (
            torch._inductor.metrics.generated_kernel_count == 1
        ), "expected Inductor to generate exactly one (fused) kernel"
        row.extend([tflops(torch_mm_ms), tflops(triton_mm_ms)])

    p.add_row(row)


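# Driver: benchmark every (shape, fusion type) combination and print the table.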
fusion_types = ["", "add", "relu", "add_relu"]
shapes = [
    # alexnet
    ([128, 9216], [9216, 4096]),
    ([128, 4096], [4096, 4096]),
    ([128, 4096], [4096, 1000]),
    # BERT
    ([2048, 768], [768, 768]),
    ([2048, 768], [768, 3072]),
    ([2048, 3072], [3072, 768]),
    # hf_GPT2
    ([1024, 768], [768, 768]),
    ([1024, 768], [768, 3072]),
    ([1024, 3072], [3072, 768]),
    ([1024, 768], [768, 2304]),
]
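# Build the result table: one row per layer, and a torch/triton throughput
# column pair per fusion type.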
p = PrettyTable()
field_names = ["layer"]
for fusion_type in fusion_types:
    if fusion_type == "":
        field_names.append("torch mm")
        field_names.append("triton mm")
    else:
        field_names.append(f"torch mm+{fusion_type}")
        field_names.append(f"triton mm+{fusion_type}")

p.field_names = field_names
p.float_format = ".3"
for layer_id, shape in enumerate(shapes):
    bench(shape, layer_id, p, fusion_types)

print(p)