File: eager_fusion.py

Package: pytorch-cuda 2.6.0+dfsg-7 (area: contrib; suites: forky, sid, trixie)

import time

import torch
from functorch.compile import aot_function, tvm_compile


# Broadcasting: a (2000, 1, 4) * b (1, 2000, 4) -> (2000, 2000, 4); summing over
# dim 0 reduces back to (2000, 4), so a fused kernel avoids materializing the
# large intermediate.
a = torch.randn(2000, 1, 4, requires_grad=True)
b = torch.randn(1, 2000, 4)


def f(a):
    # `b` is closed over from module scope; only `a` requires grad.
    return (a * b).sum(dim=0)


# Compile f's forward and backward graphs ahead of time via TVM (LLVM CPU target).
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
compiled_f = aot_function(f, fw_compiler, bw_compiler)

# No-op compilers, useful for measuring the overhead of graph capture alone:
# fw_compiler = lambda x, _: x
# bw_compiler = lambda x, _: x
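

# A minimal debugging aid (a sketch, not part of the upstream script): an AOT
# compiler here is any function that takes an FX GraphModule plus its example
# inputs and returns a callable, so a pass-through version can print the
# captured forward/backward graphs. The name `print_compiler` is hypothetical.
def print_compiler(fx_module, example_inputs):
    print(fx_module.code)  # Python source of the captured graph
    return fx_module  # a GraphModule is itself callable, so reuse it directly
# inspect_f = aot_function(f, print_compiler, print_compiler)
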
iters = 10
# Warm up: the first call traces f and compiles the forward and backward graphs.
out = compiled_f(a)
out.sum().backward()
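
# Optional sanity check (a sketch, not part of the original script): the AOT/TVM
# path should agree numerically with plain eager execution.
# torch.testing.assert_close(compiled_f(a), f(a))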


def bench(func):
    # Time `iters` forward+backward passes of func composed with sin().
    begin = time.time()
    for _ in range(iters):
        out = func(a).sin()
        out.sum().backward()
        a.grad = None  # reset so gradients are recomputed, not accumulated
    print(time.time() - begin)
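

# An alternative timing harness (a sketch; `bench_with_timer` is a hypothetical
# name): torch.utils.benchmark.Timer handles warm-up and per-run statistics,
# unlike the bare time.time() loop above.
def bench_with_timer(func):
    from torch.utils.benchmark import Timer

    timer = Timer(
        stmt="func(a).sin().sum().backward(); a.grad = None",
        globals={"func": func, "a": a},
    )
    print(timer.timeit(iters))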


def bench_jax():
    import jax
    import jax.numpy as jnp

    jax_a = jnp.array(a.detach().numpy())
    jax_b = jnp.array(b.detach().numpy())

    def f(a):
        return jnp.sin((a * jax_b).sum(axis=[0])).sum()

    jit_f = jax.jit(jax.grad(f))
    jit_f(jax_a)  # warm up: trigger tracing and XLA compilation
    begin = time.time()
    for _ in range(iters):
        out = jit_f(jax_a)
    out.block_until_ready()  # wait for async dispatch before stopping the clock
    print(time.time() - begin)


bench(f)  # eager baseline
bench(compiled_f)  # AOT-compiled TVM version
# bench_jax()  # requires jax; compares against a jitted JAX implementation