File: tensor_layout_mini_benchmark.py

import torch
from torch._inductor import ir
from torch._inductor.runtime.benchmarking import benchmarker


def to_channels_last(x):
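    """Return a copy of the NCHW tensor `x` with channels-last (NHWC) strides.

    The copy owns its own storage; only the stride layout differs from a
    contiguous clone, so the values still compare equal to `x`.
    """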
    assert x.dim() == 4

    # NCHW -> NHWC. stride_order[i] is the rank of dim i's stride, with 0 the
    # innermost (stride-1) dim: here C=0, W=1, H=2, N=3, i.e. channels last.
    stride_order = [3, 0, 2, 1]
    y = x.clone().as_strided(
        x.shape,
        ir.FlexibleLayout.stride_ordered(x.shape, stride_order),
    )
    y.copy_(x)
    assert torch.allclose(x, y)
    return y
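
# Note: for 4-D tensors the same layout can also be produced with the built-in
# memory-format API, e.g. x.clone().contiguous(memory_format=torch.channels_last);
# the explicit strided copy above presumably exists to exercise inductor's
# FlexibleLayout.stride_ordered helper directly.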


def bench_conv(with_stack=True):
    x = torch.rand(256, 3, 224, 224).cuda()
    weight = torch.rand(64, 3, 7, 7).cuda()

    x_chan = to_channels_last(x)
    weight_chan = to_channels_last(weight)
    kwargs = {
        "stride": [2, 2],
        "padding": [3, 3],
        "dilation": [1, 1],
        "transposed": False,
        "output_padding": [0, 0],
        "groups": 1,
    }
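    # These settings match a ResNet-style 7x7/2 stem convolution: 64 filters,
    # stride 2, padding 3, applied to a 3x224x224 input.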

    def baseline_fn():
        return torch.convolution(x, weight, bias=None, **kwargs)

    def test_fn():
        return torch.convolution(x_chan, weight_chan, bias=None, **kwargs)
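
    # On CUDA the channels-last variant typically dispatches to different
    # cuDNN kernels (NHWC is the tensor-core friendly layout), which is the
    # layout effect this benchmark isolates.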

    # Warm up both paths so one-time costs (cuDNN algorithm selection, lazy
    # module loading) don't pollute the profile below.
    baseline_fn()
    test_fn()

    torch.cuda.synchronize()
    with torch.profiler.profile(with_stack=with_stack) as p:
        baseline_out = baseline_fn()
        test_out = test_fn()
        torch.cuda.synchronize()

    p.export_chrome_trace("/tmp/chrome.json")
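    # The trace can be loaded in chrome://tracing or https://ui.perfetto.dev to
    # inspect which convolution kernels each layout dispatched to.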
    assert torch.allclose(baseline_out, test_out, atol=1e-3, rtol=1e-3), (
        baseline_out[0][0][0][:32],
        test_out[0][0][0][:32],
    )

    baseline_ms = benchmarker.benchmark_gpu(baseline_fn, rep=40)
    test_ms = benchmarker.benchmark_gpu(test_fn, rep=40)
    print(f"baseline {baseline_ms} test {test_ms} speedup {baseline_ms / test_ms:.3f}x")


def main():
    bench_conv()


if __name__ == "__main__":
    main()