File: benchmark.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (190 lines) | stat: -rw-r--r-- 5,143 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import torch
from functorch.compile import memory_efficient_fusion
import benchmark_helper


# Placement and precision used by every benchmark's ``args()`` factory in this file.
device, dtype = "cuda", torch.float16

# LightSeq pattern 1
class DropoutResBias:
    @staticmethod
    def fn(input, bias, residual):
        a = torch.add(input, bias)
        b = torch.nn.functional.dropout(a, p=0.7, training=True)
        c = b + residual
        return c

    @staticmethod
    def args():
        batch_size, seq_len, hidden_size = 32, 196, 1024
        input = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=True,
            device=device,
            dtype=dtype,
        )
        bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
        residual = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=False,
            device=device,
            dtype=dtype,
        )
        args = (input, bias, residual)
        return args


class DropoutResBiasScalar:
    @staticmethod
    def fn(input, bias, residual, p: float):
        a = torch.add(input, bias)
        b = torch.nn.functional.dropout(a, p, training=True)
        c = b + residual
        return c

    @staticmethod
    def args():
        batch_size, seq_len, hidden_size = 32, 196, 1024
        input = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=True,
            device=device,
            dtype=dtype,
        )
        bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
        residual = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=False,
            device=device,
            dtype=dtype,
        )
        args = (input, bias, residual, 0.7)
        return args



# LightSeq pattern 2
class BiasReluDropout:
    @staticmethod
    def fn(input, bias):
        a = torch.add(input, bias)
        b = torch.nn.functional.relu(a)
        c = torch.nn.functional.dropout(b, p=0.6, training=True)
        return c

    @staticmethod
    def args():
        batch_size = 32
        seq_len = 196
        intermediate_size = 4096
        input = torch.randn(
            batch_size,
            seq_len,
            intermediate_size,
            requires_grad=True,
            device=device,
            dtype=dtype,
        )
        bias = torch.randn(
            intermediate_size, requires_grad=True, device=device, dtype=dtype
        )
        args = (input, bias)
        return args


class BiasDropoutResLayerNorm:
    """Bias add -> dropout (p=0.7) -> residual add -> LayerNorm over the last dim."""

    @staticmethod
    def fn(input, bias, residual):
        """Return ``layer_norm(dropout(input + bias, p=0.7) + residual)``.

        Dropout runs in training mode. LayerNorm has no affine weight/bias and
        normalizes over the last dimension. That width is read from the tensor
        rather than hard-coded, so ``fn`` cannot drift out of sync with
        ``args()``. For the default 1024-wide inputs the result is unchanged.
        """
        a = torch.add(input, bias)
        b = torch.nn.functional.dropout(a, p=0.7, training=True)
        c = b + residual
        return torch.nn.functional.layer_norm(c, normalized_shape=(c.shape[-1],))

    @staticmethod
    def args():
        """Build ``(input, bias, residual)`` on ``device`` with ``dtype``.

        ``input`` and ``residual`` are ``(32, 196, 1024)``; ``bias`` is
        ``(1024,)``. ``input`` and ``bias`` require grad, ``residual`` does not.
        """
        batch_size = 32
        seq_len = 196
        hidden_size = 1024

        input = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=True,
            device=device,
            dtype=dtype,
        )
        bias = torch.randn(hidden_size, requires_grad=True, device=device, dtype=dtype)
        residual = torch.randn(
            batch_size,
            seq_len,
            hidden_size,
            requires_grad=False,
            device=device,
            dtype=dtype,
        )
        args = (input, bias, residual)
        return args


class LayerNormSigmoid:
    """LayerNorm over the last dim followed by an elementwise sigmoid."""

    @staticmethod
    def fn(inp):
        """Return ``sigmoid(layer_norm(inp))``, normalizing over ``inp``'s last dim.

        The normalized width comes from ``inp.shape[-1]`` instead of a duplicated
        hard-coded 512, so it always matches the input. For the default 512-wide
        inputs from ``args()`` the result is unchanged.
        """
        normalized = torch.nn.functional.layer_norm(
            inp, normalized_shape=(inp.shape[-1],)
        )
        return torch.sigmoid(normalized)

    @staticmethod
    def args():
        """Build a single ``(8192, 512)`` input on ``device`` with ``dtype`` (requires grad)."""
        batch_size = 8192
        hidden_size = 512
        inp = torch.randn(
            batch_size, hidden_size, requires_grad=True, device=device, dtype=dtype
        )
        args = (inp,)
        return args


def _bind_static_args(fn, args):
    """Bake the non-tensor entries of ``args`` into ``fn`` as constants.

    Returns ``(wrapped_fn, tensor_args)`` where ``wrapped_fn(*tensor_args)``
    is equivalent to ``fn(*args)``. Every positional input of ``wrapped_fn`` is
    a tensor, so it can be handed to ``memory_efficient_fusion`` without
    ``static_argnums``.
    """
    constants = {
        idx: arg for idx, arg in enumerate(args) if not isinstance(arg, torch.Tensor)
    }
    tensor_args = tuple(arg for arg in args if isinstance(arg, torch.Tensor))
    num_args = len(args)

    def wrapped_fn(*tensors):
        if len(tensors) != len(tensor_args):
            raise TypeError(
                f"expected {len(tensor_args)} tensor arguments, got {len(tensors)}"
            )
        tensor_iter = iter(tensors)
        # Re-interleave constants and tensors in the original positional order.
        full_args = [
            constants[idx] if idx in constants else next(tensor_iter)
            for idx in range(num_args)
        ]
        return fn(*full_args)

    return wrapped_fn, tensor_args


for bench_cls in [
    DropoutResBias,
    BiasReluDropout,
    DropoutResBiasScalar,
    BiasDropoutResLayerNorm,
    LayerNormSigmoid,
]:
    # `fn` and `args` are static methods, so no instance is needed.
    fn = bench_cls.fn
    args = bench_cls.args()

    # Non-tensor inputs (e.g. DropoutResBiasScalar's dropout rate) are bound as
    # constants instead of being passed via `static_argnums`.
    # NOTE(review): AOTAutograd in this functorch era rejects a non-None
    # `static_argnums` and recommends wrapping the function manually; binding
    # the constants here works either way.
    tensor_only_fn, tensor_args = _bind_static_args(fn, args)
    opt_fn = memory_efficient_fusion(tensor_only_fn)

    # Profile CUDA kernels, then time with the torch Timer and a manual timer,
    # each time running eager first and the fused (nvFuser) version second.
    for measure in (
        benchmark_helper.profile_cuda_kernels,
        benchmark_helper.time_with_torch_timer,
        benchmark_helper.time_with_manual_timer,
    ):
        measure(fn, args, "Eager")
        with torch.jit.fuser("fuser2"):
            measure(opt_fn, tensor_args, "AOTAutograd")