File: bench_ops.py

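"""Microbenchmark comparing eager PyTorch with TorchScript/NNC-fused kernels.

Runs a set of pointwise unary ops and F.batch_norm over a range of shapes,
printing total seconds over 100 runs for eager and traced execution, plus
the eager/nnc speedup ratio.
"""
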
import timeit
import torch
import torch.nn.functional as F

# Enable CPU fusion in the TorchScript fuser (NNC) and disable fusion-group
# inlining so fused groups are kept even when small. Note these are private,
# underscore-prefixed APIs and may change between releases. Run single-threaded
# for stable timings.
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._debug_set_fusion_group_inlining(False)
torch.set_num_threads(1)


# Hand-decomposed hardswish, x * clamp(x + 3, 0, 6) / 6, to exercise the
# fuser on a chain of primitives; the fused aten kernel is benchmarked below.
def hardswish(x):
    return x * torch.clamp(x + 3.0, 0.0, 6.0) / 6.0


# Pointwise unary ops to benchmark; all are candidates for NNC fusion.
unary_ops = [
    hardswish,
    torch._C._nn.hardswish,  # the fused aten kernel, for comparison
    torch.sigmoid,
    torch.reciprocal,
    torch.neg,
    torch.relu,
    torch.isnan,
    torch.log,
    torch.log10,
    torch.log1p,
    torch.log2,
    torch.exp,
    torch.expm1,
    torch.erf,
    torch.erfc,
    torch.cos,
    torch.sin,
    torch.tan,
    torch.acos,
    torch.asin,
    torch.cosh,
    torch.sinh,
    torch.atan,
    torch.tanh,
    torch.sqrt,
    torch.rsqrt,
    torch.abs,
    torch.ceil,
    torch.floor,
    torch.round,
    torch.trunc,
    torch.lgamma,
]

print("{:20s} {:>10s} {:>10s} {:>10s}".format("op", "eager", "nnc", "speedup"))

for op in unary_ops:
    x = torch.rand((1024, 1024))
    traced = torch.jit.trace(lambda x: op(x), (x,))

    # Warmup.
    warmup_iters = 8
    for _ in range(warmup_iters):
        op(x)
        traced(x)

    # Validate result.
    torch.testing.assert_close(op(x), traced(x))
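
    # Optional debug aid: after warmup, `traced.graph_for(x)` returns the
    # optimized graph; a prim::TensorExprGroup node there means NNC fusion
    # kicked in. Uncomment to inspect:
    # print(traced.graph_for(x))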

    # Benchmark.
    bench_iters = 100
    teager = timeit.timeit(stmt="op(x)", globals=globals(), number=bench_iters)
    tjit = timeit.timeit(stmt="traced(x)", globals=globals(), number=bench_iters)
    print(f"{op.__name__:20s} {teager:10.3f} {tjit:10.3f} {teager/tjit:10.2f}")


def test_batch_norm():
    op = F.batch_norm
    print("{:20s} {:20s} {:>10s} {:>10s} {:>10s}".format("op", "shape", "eager", "nnc", "speedup"))
    # (N, C, H, W) activation shapes typical of ResNet stages, at batch
    # sizes 1 and 5.
    batch_norm_shapes = [
        [1, 64, 112, 112],
        [1, 256, 14, 14],
        [1, 128, 28, 28],
        [1, 64, 56, 56],
        [1, 512, 7, 7],
        [5, 64, 112, 112],
        [5, 256, 14, 14],
        [5, 128, 28, 28],
        [5, 64, 56, 56],
        [5, 512, 7, 7],
    ]
    for n, c, h, w in batch_norm_shapes:
        x = torch.rand((n, c, h, w))
        # y and z serve as running_mean and running_var (batch_norm defaults
        # to inference mode when `training` is not passed).
        y = torch.rand(c)
        z = torch.rand(c)
        traced = torch.jit.trace(lambda x, y, z: op(x, y, z), (x, y, z))

        # Warmup.
        warmup_iters = 8
        for _ in range(warmup_iters):
            op(x, y, z)
            traced(x, y, z)

        # Validate result.
        torch.testing.assert_close(op(x, y, z), traced(x, y, z))

        # Benchmark.
        bench_iters = 100
        # timeit compiles the stmt strings, so the loop-local names must be
        # passed in via its `globals` argument.
        teager = timeit.timeit(stmt="op(x, y, z)", globals=locals(), number=bench_iters)
        tjit = timeit.timeit(stmt="traced(x, y, z)", globals=locals(), number=bench_iters)
        print(f"{op.__name__:20s} ({n:>3d}, {c:>3d}, {h:>3d}, {w:>3d}) {teager:10.3f} {tjit:10.3f} {teager/tjit:10.2f}")


test_batch_norm()