1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
|
import operator_benchmark as op_bench
import torch
"""Microbenchmarks for diag operator"""
# Configs for PT diag operator
diag_configs_short = op_bench.config_list(
attr_names=["dim", "M", "N", "diagonal", "out"],
attrs=[
[1, 64, 64, 0, True],
[2, 128, 128, -10, False],
[1, 256, 256, 20, True],
],
cross_product_configs={
"device": ["cpu", "cuda"],
},
tags=["short"],
)
class DiagBenchmark(op_bench.TorchBenchmarkBase):
def init(self, dim, M, N, diagonal, out, device):
self.inputs = {
"input": torch.rand(M, N, device=device)
if dim == 2
else torch.rand(M, device=device),
"diagonal": diagonal,
"out": out,
"out_tensor": torch.tensor(
(),
device=device,
),
}
self.set_module_name("diag")
def forward(self, input, diagonal: int, out: bool, out_tensor):
if out:
return torch.diag(input, diagonal=diagonal, out=out_tensor)
else:
return torch.diag(input, diagonal=diagonal)
op_bench.generate_pt_test(diag_configs_short, DiagBenchmark)
if __name__ == "__main__":
op_bench.benchmark_runner.main()
|