File: nested_bmm_bench.py

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (66 lines) | stat: -rw-r--r-- 1,823 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import random

import torch


def bench(nt_a, nt_b, niter):
    """Time ``nt_a.bmm(nt_b)`` with CUDA events and return the mean runtime.

    One untimed warmup call is made first; then ``niter`` calls are timed
    between a pair of ``torch.cuda.Event`` markers.

    Args:
        nt_a: Left operand; must expose ``.bmm(nt_b)`` (in this script, a
            nested tensor on a CUDA device).
        nt_b: Right operand passed to ``nt_a.bmm``.
        niter: Number of timed iterations; must be >= 1.

    Returns:
        Mean time per ``bmm`` call in milliseconds (the unit reported by
        ``torch.cuda.Event.elapsed_time``).

    Raises:
        ValueError: If ``niter`` is less than 1.
    """
    # Validate before touching the GPU: niter == 0 would otherwise divide by
    # zero after running the warmup, and a negative niter would silently
    # report a negative runtime.
    if niter < 1:
        raise ValueError(f"niter must be >= 1, got {niter}")

    # Warmup call, kept out of the timed region (presumably to absorb any
    # one-time setup cost of the first call).
    nt_a.bmm(nt_b)

    # Drain queued work so the start event is not recorded behind the warmup.
    torch.cuda.synchronize()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(niter):
        nt_a.bmm(nt_b)
    end_event.record()
    # Both events must have completed before elapsed_time can be queried.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / niter


def sweep_n(niter, dtype, ntensors=(4, 8, 16, 32, 64, 128, 256)):
    """Benchmark nested-tensor ``bmm`` over several batch sizes; print CSV rows.

    For each ``ntensor`` in ``ntensors``:

    - build ``ntensor`` random matrices of shape ``(256, L)``, where each
      ``L`` is drawn independently with ``random.randint(100, 200)``;
    - pack them into the nested tensor ``nt_a``, and their transposes
      (shape ``(L, 256)``) into ``nt_b``, both on ``"cuda"`` with ``dtype``;
    - time ``nt_a.bmm(nt_b)`` via ``bench`` and print one row:
      ``ntensor,dtype,min_length,mean_length,max_length,runtime``.

    Args:
        niter: Number of timed iterations passed to ``bench``.
        dtype: dtype for both nested tensors (e.g. ``torch.float32``).
        ntensors: Batch sizes to sweep. The default is the original fixed
            sweep, so existing callers get identical behavior.
    """
    for ntensor in ntensors:
        # Lengths come from Python's ``random`` module; values from torch's RNG.
        tensors = [
            torch.randn(256, random.randint(100, 200)) for _ in range(ntensor)
        ]
        nt_a = torch.nested.nested_tensor(
            tensors,
            dtype=dtype,
            device="cuda",
        )
        nt_b = torch.nested.nested_tensor(
            [t.t() for t in tensors],
            dtype=dtype,
            device="cuda",
        )
        runtime = bench(nt_a, nt_b, niter)

        # Per-component sizes: one row per constituent tensor. Column 1 is the
        # variable length L because every constituent of nt_a is (256, L).
        nt_a_size = torch.ops.aten._nested_tensor_size(nt_a)
        lengths = nt_a_size[:, 1]
        row = [
            ntensor,
            dtype,
            lengths.min().item(),
            lengths.float().mean().item(),
            lengths.max().item(),
            runtime,
        ]
        print(",".join(map(str, row)))


if __name__ == "__main__":
    # Seed both RNGs: Python's ``random`` picks the per-tensor lengths and
    # torch's RNG fills the tensor values, so repeated runs use identical inputs.
    random.seed(123)
    torch.manual_seed(123)

    parser = argparse.ArgumentParser(description="Nested Tensor BMM Benchmark")
    parser.add_argument(
        "--niter",
        default=10,
        type=int,
        help="number of timed bmm iterations per configuration (default: 10)",
    )

    args = parser.parse_args()
    niter = args.niter
    if niter < 1:
        parser.error("--niter must be >= 1")
    # Timing uses CUDA events and the tensors are placed on "cuda"; fail fast
    # with a clear message rather than deep inside the first sweep.
    if not torch.cuda.is_available():
        raise SystemExit("A CUDA device is required to run this benchmark.")

    print("ntensor,dtype,min_length,mean_length,max_length,runtime")
    sweep_n(niter, torch.float32)
    sweep_n(niter, torch.float16)