1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
|
import timeit
from functorch.compile import compiled_module, tvm_compile
import torch.nn as nn
import torch
def nop(f, _):
    """Identity "compiler": return the function it is given, untouched.

    Used as the forward/backward compiler passed to ``compiled_module`` so
    that (presumably) the traced graph runs as-is. The second positional
    argument is accepted only to satisfy the call signature and is ignored.
    """
    compiled = f
    return compiled
# Choose the forward/backward compilers handed to ``compiled_module``.
# Set USE_TVM = True to compile with TVM (LLVM target) using the tuning log
# files named below; otherwise the identity ``nop`` compiler is used and the
# traced functions run unchanged.
USE_TVM = False
if USE_TVM:
    fw_compiler = tvm_compile(target='llvm', tuning_logfile='fw_keops')
    bw_compiler = tvm_compile(target='llvm', tuning_logfile='bw_keops')
else:
    fw_compiler = nop
    bw_compiler = nop
def run(mod, input):
    """Run one forward/backward pass of ``mod`` on ``input``.

    Gradients already stored on ``mod``'s parameters are cleared first, so the
    returned gradients belong to this call alone rather than being summed
    with those of earlier calls.

    Args:
        mod: module to evaluate; its output is reduced with ``.sum()`` before
            ``backward()`` is called.
        input: argument passed to ``mod`` (name kept for caller
            compatibility even though it shadows the builtin).

    Returns:
        ``(out, *grads)``: the module output followed by a detached copy of
        each parameter's gradient, in ``mod.parameters()`` order. Parameters
        that received no gradient contribute ``None``.
    """
    # backward() accumulates into ``p.grad``; without clearing, a second call
    # would report the sum of both passes.
    for p in mod.parameters():
        p.grad = None
    out = mod(input)
    out.sum().backward()
    # Snapshot the gradients: ``p.grad`` may be updated in place by a later
    # backward(), which would silently change tuples returned earlier.
    grads = [
        None if p.grad is None else p.grad.detach().clone()
        for p in mod.parameters()
    ]
    return (out, *grads)
class Foo(nn.Module):
    """Toy module computing ``(param * x + buf).sum(dim=0)``.

    Holds one trainable parameter ``param`` and one registered buffer
    ``buf``, each initialised from ``torch.randn(1)`` (shape ``(1,)``).
    """

    def __init__(self):
        super().__init__()
        # Draw the parameter first, then the buffer, so seeded runs see the
        # same random values in the same order.
        weight = torch.randn(1)
        self.param = nn.Parameter(weight)
        offset = torch.randn(1)
        self.register_buffer("buf", offset)

    def forward(self, x):
        affine = self.param * x + self.buf
        return affine.sum(dim=0)
# Correctness check: eager vs. compiled outputs and gradients, both before
# and after the parameters are modified (the second check runs the compiled
# module against parameter values changed after it was built).
input = torch.randn(1)
mod = Foo()
compiled_mod = compiled_module(mod, fw_compiler, bw_compiler)

for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_allclose(a, b)

# Take a plain gradient-descent step (implicit learning rate of 1).
out = mod(input)
out.sum().backward()
with torch.no_grad():
    mod.param -= mod.param.grad
    # NOTE(review): if ``compiled_mod.orig_module`` is ``mod`` itself (shared
    # parameters), this applies the step a second time to the same tensor;
    # confirm whether that is intended.
    compiled_mod.orig_module.param -= compiled_mod.orig_module.param.grad
compiled_mod.orig_module.param.grad = None

for a, b in zip(run(mod, input), run(compiled_mod, input)):
    torch.testing.assert_allclose(a, b)

# Micro-benchmark: average forward-call latency in microseconds.
# ``i`` is used both as the call count and as the divisor, so the two can
# never drift apart.
i = 10000
for _ in range(5):
    t = timeit.Timer("mod(input)", globals=globals()).timeit(i)
    print(f"eager {t/i*1e6}")
    t = timeit.Timer("compiled_mod(input)", globals=globals()).timeit(i)
    print(f"compiled {t/i*1e6}")
|