File: sdp.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (362 lines) | stat: -rw-r--r-- 10,912 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
import argparse
import itertools
import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from pprint import pprint
from typing import List, Optional

import numpy as np
from prettytable import PrettyTable
from tqdm import tqdm

import torch
import torch.utils.benchmark as benchmark
from torch.backends.cuda import sdp_kernel


# Suppress every warning process-wide for the whole benchmark run.
# NOTE(review): presumably this hides deprecation noise (e.g. from
# torch.backends.cuda.sdp_kernel) in benchmark output — confirm it is still wanted.
warnings.simplefilter("ignore")


@dataclass(frozen=True)
class ExperimentConfig:
    """Inputs for one benchmark run.

    Holds tensor sizes, dtype, padding, and which scaled-dot-product-attention
    backends are enabled.

    ``get_entries`` and ``get_entry_names`` return parallel lists (one table
    row and its header). Both are derived from the dataclass fields, so adding
    or reordering a field cannot leave them out of sync.
    """

    batch_size: int
    num_heads: int
    max_sequence_len: int
    embed_dimension: int
    dtype: torch.dtype
    # None requests a dense (unpadded) input batch.
    pad_percentage: Optional[float]
    # Backend toggles, passed to torch.backends.cuda.sdp_kernel by the runner.
    enable_math: bool
    enable_flash: bool
    enable_mem_efficient: bool
    enable_cudnn: bool

    def get_entries(self) -> List:
        """Return the field values in declaration order (one table row)."""
        from dataclasses import fields

        return [getattr(self, field.name) for field in fields(self)]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        """Return the field names in declaration order (table header)."""
        from dataclasses import fields

        return [field.name for field in fields(cls)]


@dataclass(frozen=True)
class ExperimentResults:
    """Mean runtimes, in microseconds, measured for one configuration.

    The ``compiled_*`` fields are ``None`` when the torch.compile'd variants
    were not benchmarked.
    """

    nn_mha_time: float
    compiled_nn_mha_time: Optional[float]
    composite_mha_time: float
    compiled_composite_mha_time: Optional[float]

    @staticmethod
    def _format_time(value: Optional[float]) -> Optional[str]:
        # Two decimal places (the original ":2f" was a width, not a precision).
        # Test against None explicitly so a genuine 0.0 is still reported.
        return None if value is None else f"{value:.2f}"

    def get_entries(self) -> List:
        """Return the formatted times in the same order as ``get_entry_names``."""
        return [
            self._format_time(self.nn_mha_time),
            self._format_time(self.compiled_nn_mha_time),
            self._format_time(self.composite_mha_time),
            self._format_time(self.compiled_composite_mha_time),
        ]

    @classmethod
    def get_entry_names(cls) -> List[str]:
        """Return the table headers for ``get_entries`` (units: microseconds)."""
        return [
            "nn_mha_time (\u00b5s)",
            "compiled_nn_mha_time (\u00b5s)",
            "composite_mha_time (\u00b5s)",
            "compiled_composite_mha_time (\u00b5s)",
        ]


@dataclass(frozen=True)
class Experiment:
    """A benchmark configuration paired with the timings it produced."""

    config: ExperimentConfig
    results: ExperimentResults

    def get_entries(self) -> List:
        """Return one table row: the config columns, then the result columns."""
        row = list(self.config.get_entries())
        row.extend(self.results.get_entries())
        return row


class CompositeMHA(torch.nn.Module):
    """Self-attention built from ``F.linear`` and ``F.scaled_dot_product_attention``.

    It reuses a packed input projection (weight/bias) and an output projection
    module handed in by the caller, so it can be compared against the
    ``nn.MultiheadAttention`` those weights came from.
    """

    def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
        super().__init__()
        # Keep this assignment order: it fixes parameter/submodule registration order.
        self.in_proj_weight = in_proj_weight
        self.in_proj_bias = in_proj_bias
        self.out_proj = out_proj
        self.num_heads = num_heads

    def forward(self, query, key, value, mask):
        """Apply self-attention to ``query``.

        Args:
            query, key, value: Must all be the very same tensor object.
            mask: Must be ``None``; masking is not implemented.

        Returns:
            ``(output, None)``, mirroring the ``(output, weights)`` tuple of
            ``nn.MultiheadAttention``.

        Raises:
            NotImplementedError: if the inputs differ or a mask is given.
        """
        is_self_attention = query is key and key is value
        if not is_self_attention:
            raise NotImplementedError(
                "query, key and value must be the same Tensor for now."
            )
        if mask is not None:
            raise NotImplementedError("mask is currently not supported.")

        # A single projection produces q, k and v packed along the last dim.
        packed = torch.nn.functional.linear(
            query, self.in_proj_weight, self.in_proj_bias
        )

        n_batch = packed.size(0)
        # The last dim of `packed` holds q|k|v, hence the extra factor of 3.
        per_head = packed.size(2) // (self.num_heads * 3)

        def to_heads(part):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return part.view(n_batch, -1, self.num_heads, per_head).transpose(1, 2)

        q, k, v = (to_heads(part) for part in packed.chunk(3, -1))

        # SDPA output layout: (batch, num_heads, seq_len, head_dim)
        heads_out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
        )

        merged = heads_out.transpose(1, 2).reshape(
            n_batch, -1, self.num_heads * per_head
        )
        return self.out_proj(merged), None


def build_composite_mha_from_nn_mha(pt):
    """Build a ``CompositeMHA`` that shares the weights of ``pt``.

    The packed ``in_proj_weight``/``in_proj_bias`` and the ``out_proj``
    module of ``pt`` are passed through as-is (shared, not copied).

    Args:
        pt: An ``nn.MultiheadAttention`` instance.

    Returns:
        A ``CompositeMHA`` with ``pt.num_heads`` heads.

    Raises:
        ValueError: if ``pt`` uses separate q/k/v projections, has no packed
            ``in_proj_weight``, or was not built with ``batch_first=True``
            (CompositeMHA treats dim 0 of its input as the batch).
    """
    # Explicit checks instead of `assert` so validation survives `python -O`.
    if not pt._qkv_same_embed_dim:
        raise ValueError(
            "CompositeMHA requires query, key and value to share embed_dim "
            "(a packed in_proj_weight)."
        )
    in_proj_weight = pt.in_proj_weight
    if in_proj_weight is None:
        raise ValueError("nn.MultiheadAttention has no packed in_proj_weight.")
    if not pt.batch_first:
        raise ValueError(
            "CompositeMHA requires an nn.MultiheadAttention built with batch_first=True."
        )
    return CompositeMHA(pt.num_heads, in_proj_weight, pt.in_proj_bias, pt.out_proj)


def generate_rand_batch(
    batch_size,
    max_sequence_len,
    embed_dimension,
    pad_percentage=None,
    dtype=torch.float16,
    device="cuda",
):
    """Create a random input batch for the attention benchmarks.

    Args:
        batch_size: Number of sequences in the batch.
        max_sequence_len: Length of the longest sequence.
        embed_dimension: Size of the last (feature) dimension.
        pad_percentage: Target fraction of each sequence that is padding.
            Falsy (``None`` or ``0``) produces a dense tensor. Otherwise each
            length is ``int(max_sequence_len * (1 - gauss(pad_percentage, 0.01)))``
            clamped to ``[1, max_sequence_len]``.
        dtype: Element dtype of the generated tensors.
        device: Device the tensors are allocated on.

    Returns:
        ``(tensor, None)`` with a dense tensor of shape
        ``(batch_size, max_sequence_len, embed_dimension)`` when unpadded;
        otherwise ``(nested_tensor, seq_len_list)``, where one randomly chosen
        sequence always has length ``max_sequence_len``.

    Sequence lengths come from the stdlib ``random`` module; seed it for
    reproducible batches.
    """
    if not pad_percentage:
        dense = torch.randn(
            batch_size,
            max_sequence_len,
            embed_dimension,
            dtype=dtype,
            device=device,
        )
        return dense, None

    def _sample_len():
        # Gaussian draw around the requested padding fraction. Clamp it: a
        # padding near/above 1.0 would give lengths <= 0 (torch.randn rejects
        # negative sizes), and one near 0.0 could exceed max_sequence_len.
        raw = int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))
        return min(max_sequence_len, max(1, raw))

    seq_len_list = [_sample_len() for _ in range(batch_size)]
    # Force one randomly chosen sequence to the full length.
    seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len

    # One randn call per sequence: slow, but simple.
    nested = torch.nested.nested_tensor(
        [
            torch.randn(seq_len, embed_dimension, dtype=dtype, device=device)
            for seq_len in seq_len_list
        ]
    )
    return nested, seq_len_list


def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Return the mean runtime of ``f(*args, **kwargs)`` in microseconds.

    Timing uses ``torch.utils.benchmark.Timer.blocked_autorange``; the
    measurement's mean (seconds per call) is converted to microseconds.
    """
    namespace = {"f": f, "args": args, "kwargs": kwargs}
    timer = benchmark.Timer(stmt="f(*args, **kwargs)", globals=namespace)
    measurement = timer.blocked_autorange()
    seconds_per_call = measurement.mean
    return seconds_per_call * 1e6


def assert_close_tensors(tensor_a, tensor_b):
    """Raise ``AssertionError`` unless two (possibly nested) tensors are close.

    First order sanity check. Not a replacement for rigorous tests.
    NestedTensors are compared component by component with atol=rtol=1e-2;
    dense tensors with atol=rtol=1e-3.

    Failures are raised explicitly (not via ``assert``) so the check still
    runs under ``python -O``.
    """
    if tensor_a.is_nested != tensor_b.is_nested:
        raise AssertionError(
            "cannot compare a nested tensor with a dense one "
            f"(is_nested: {tensor_a.is_nested} vs {tensor_b.is_nested})"
        )

    if tensor_a.is_nested:
        parts_a = tensor_a.unbind()
        parts_b = tensor_b.unbind()
        if len(parts_a) != len(parts_b):
            # zip() alone would silently ignore the extra components.
            raise AssertionError(
                f"nested tensors have {len(parts_a)} vs {len(parts_b)} components"
            )
        pairs = list(zip(parts_a, parts_b))
        tol = 1e-2
    else:
        pairs = [(tensor_a, tensor_b)]
        tol = 1e-3

    for index, (a, b) in enumerate(pairs):
        if not torch.allclose(a, b, atol=tol, rtol=tol):
            # Upcast before diffing so low-precision dtypes report a usable number.
            max_diff = (a.float() - b.float()).abs().max().item()
            raise AssertionError(
                f"tensors differ at component {index}: max abs diff "
                f"{max_diff:.3g} (atol=rtol={tol})"
            )


def run_single_experiment(config: ExperimentConfig) -> Experiment:
    """Time nn.MultiheadAttention against CompositeMHA for one configuration.

    Both modules share the same weights. Their outputs are compared once with
    ``assert_close_tensors`` before any timing. The torch.compile'd variants
    are timed only for dense inputs; padded inputs are NestedTensors, which
    TorchDynamo cannot handle, so those timings are reported as ``None``.

    Args:
        config: Sizes, dtype, padding and SDPA backend toggles for the run.

    Returns:
        An ``Experiment`` pairing ``config`` with the mean times in microseconds.
    """
    dropout_p = 0.0
    mask = None

    # Weights and inputs are created outside inference_mode so that they are
    # ordinary tensors rather than inference tensors.
    nn_mha = torch.nn.MultiheadAttention(
        embed_dim=config.embed_dimension,
        num_heads=config.num_heads,
        batch_first=True,
        dropout=dropout_p,
    )
    nn_mha = nn_mha.eval().to("cuda", config.dtype)
    composite_mha = build_composite_mha_from_nn_mha(nn_mha)
    qkv, lengths = generate_rand_batch(
        config.batch_size,
        config.max_sequence_len,
        config.embed_dimension,
        config.pad_percentage,
        config.dtype,
    )

    # inference_mode: per the nn.MultiheadAttention docs, NestedTensor inputs
    # are only supported on its inference fast path, which is not taken while
    # autograd is enabled and the projection weights require grad. It also
    # keeps autograd bookkeeping out of the measured times.
    with sdp_kernel(
        enable_math=config.enable_math,
        enable_flash=config.enable_flash,
        enable_mem_efficient=config.enable_mem_efficient,
        enable_cudnn=config.enable_cudnn,
    ), torch.inference_mode():
        nn_mha_output, _ = nn_mha(qkv, qkv, qkv, mask)
        composite_mha_output, _ = composite_mha(qkv, qkv, qkv, mask)

        # First order sanity check
        assert_close_tensors(nn_mha_output, composite_mha_output)

        nn_mha_time = benchmark_torch_function_in_microseconds(
            nn_mha, qkv, qkv, qkv, mask
        )
        composite_mha_time = benchmark_torch_function_in_microseconds(
            composite_mha, qkv, qkv, qkv, mask
        )

        # TorchDynamo will error on NestedTensors. `lengths` is None exactly
        # when generate_rand_batch returned a dense tensor.
        if lengths is None:
            compiled_nn_mha = torch.compile(nn_mha)
            compiled_composite_mha = torch.compile(composite_mha)

            # Compile ahead of timing so the one-off compilation cost does not
            # distort the timer's calibration runs.
            compiled_nn_mha(qkv, qkv, qkv, mask)
            compiled_composite_mha(qkv, qkv, qkv, mask)

            compiled_nn_mha_time = benchmark_torch_function_in_microseconds(
                compiled_nn_mha, qkv, qkv, qkv, mask
            )
            compiled_composite_mha_time = benchmark_torch_function_in_microseconds(
                compiled_composite_mha, qkv, qkv, qkv, mask
            )
        else:
            compiled_nn_mha_time = None
            compiled_composite_mha_time = None

    results = ExperimentResults(
        nn_mha_time,
        compiled_nn_mha_time,
        composite_mha_time,
        compiled_composite_mha_time,
    )
    return Experiment(config, results)


def generate_experiments(
    batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
) -> List[ExperimentConfig]:
    """Build one ExperimentConfig per point of the cartesian product of the inputs.

    Configs are ordered as ``itertools.product`` yields them (the last
    argument varies fastest). All configs share the same backend toggles:
    math disabled; flash, memory-efficient and cuDNN enabled.

    A list is returned; a generator would also work for plain iteration.
    """
    grid = itertools.product(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )
    return [
        ExperimentConfig(
            batch_size=bsz,
            num_heads=heads,
            max_sequence_len=seq_len,
            embed_dimension=dim,
            dtype=dtype,
            pad_percentage=pad,
            enable_math=False,
            enable_flash=True,
            enable_mem_efficient=True,
            enable_cudnn=True,
        )
        for bsz, heads, seq_len, dim, dtype, pad in grid
    ]


def main(save_path: Optional[Path]):
    """Run the SDPA benchmarks, print the results, and optionally save a CSV.

    First runs and pretty-prints one reference configuration. Then it sweeps
    a small grid of configurations and prints them as a table.

    Args:
        save_path: Destination for the results table as CSV; ``None`` skips
            saving.
    """
    seed = 123
    # generate_rand_batch draws padded sequence lengths from the stdlib
    # `random` module, so it must be seeded as well for reproducible runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Run one timing experiment comparing nn_mha vs composite_mha
    config = ExperimentConfig(
        batch_size=128,
        num_heads=8,
        max_sequence_len=512,
        embed_dimension=512,
        dtype=torch.float16,
        pad_percentage=None,
        enable_math=False,
        enable_flash=True,
        enable_mem_efficient=True,
        enable_cudnn=True,
    )

    experiment = run_single_experiment(config)
    pprint(experiment)

    table = PrettyTable()
    table.float_format = ".3"
    table.field_names = (
        ExperimentConfig.get_entry_names() + ExperimentResults.get_entry_names()
    )

    # Run a bunch of experiments
    batch_sizes = [256]
    num_heads = [32]
    max_seq_lens = [256]
    embed_dims = [512]
    dtypes = [torch.bfloat16, torch.float16, torch.float32]
    pad_percentages = [None, 0.9]

    experiment_configs = generate_experiments(
        batch_sizes, num_heads, max_seq_lens, embed_dims, dtypes, pad_percentages
    )

    for experiment_config in tqdm(experiment_configs):
        experiment = run_single_experiment(experiment_config)
        table.add_row(experiment.get_entries())

    print(table)

    if save_path is not None:
        # The header contains a non-ASCII micro sign, so force UTF-8 instead
        # of the platform default. newline="" keeps the CSV writer's line
        # endings untranslated.
        with open(save_path, "w", encoding="utf-8", newline="") as csvfile:
            csvfile.write(table.get_csv_string())


if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--save-path", "--save_path", type=str, help="Path to save the results"
    )
    parsed = cli.parse_args()

    # A missing or empty --save-path means "do not write a CSV".
    if parsed.save_path:
        target = Path(parsed.save_path)
    else:
        target = None
    main(target)