File: cache_hit_microbenchmarks.py

Package: pytorch-cuda 2.6.0+dfsg-7
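
"""Microbenchmark for Inductor FX graph cache *hits*.

A very large graph is compiled once to populate the local on-disk cache;
the timed portion then measures how quickly torch.compile can serve the
same graph from that cache after the in-memory state has been reset.
Requires a CUDA device (the input tensor is moved to the GPU).
"""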
import os
import timeit

import torch.fx
from torch._dynamo.utils import counters
from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache


N = 10000  # number of chained sin() ops, i.e. roughly the FX graph size
K = 100  # compiled calls timed per measurement


def huge_graph(x):
    # One very large graph: N sequential pointwise ops.
    for _ in range(N):
        x = x.sin()
    return x


def main():
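    # Use only the local (on-disk) FX graph cache; keep the remote cache off.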
    torch._inductor.config.fx_graph_cache = True
    torch._inductor.config.fx_graph_remote_cache = False

    # fresh_inductor_cache() redirects Inductor to a temporary cache directory,
    # so this run neither reuses nor pollutes an existing cache.
    with fresh_inductor_cache():
        a = torch.randn(4).cuda()
        compiled_fn = torch.compile(huge_graph, backend="inductor")

        # Warm the cache: the first call compiles the graph and writes the
        # FX graph cache entry (exactly one cache miss).
        compiled_fn(a)
        assert counters["inductor"]["fxgraph_cache_miss"] == 1

        def setup():
            # Make each measurement start cold in-process: reset Dynamo, drop
            # Inductor's in-memory caches, delete the generated code modules,
            # and zero the counters. The on-disk FX graph cache entry remains.
            torch._dynamo.reset()
            clear_inductor_caches()
            for m in torch._inductor.codecache.PyCodeCache.cache.values():
                os.remove(m.__file__)
            counters.clear()

        def fn():
            # After setup(), the compiled graph must be served from the FX
            # graph cache: a hit, never a miss.
            result = compiled_fn(a)
            assert counters["inductor"]["fxgraph_cache_miss"] == 0
            assert counters["inductor"]["fxgraph_cache_hit"] == 1
            return result

        # Best of 3 repeats; each repeat runs setup() once and times K calls.
        t = min(timeit.repeat(fn, setup=setup, number=K, repeat=3))
        print(f"took {t:.1f}s")


if __name__ == "__main__":
    main()