1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
|
import operator_benchmark as op_bench
import torch
import torch.nn as nn
from pt import configs
"""Microbenchmarks for Linear operator."""
class LinearBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that runs an ``nn.Linear`` layer on a random 2-D input."""

    def init(self, N, IN, OUT, device):
        """Create the random input tensor and the layer under test.

        Args:
            N: First dimension of the random input tensor (shape ``(N, IN)``).
            IN: Second dimension of the input; also the layer's ``in_features``.
            OUT: The layer's ``out_features``.
            device: Device on which both the input and the layer are placed.
        """
        # The input is sampled before the layer is built, so the random
        # number generator is consumed in the same order on every run.
        sample = torch.rand(N, IN, device=device)
        # NOTE(review): the key presumably has to match the parameter name of
        # ``forward`` -- confirm against the benchmark base class.
        self.inputs = {"input_one": sample}

        layer = nn.Linear(IN, OUT)
        self.linear = layer.to(device=device)

        self.set_module_name("linear")

    def forward(self, input_one):
        """Apply the linear layer to ``input_one`` and return the result."""
        output = self.linear(input_one)
        return output
# Register LinearBenchmark with the benchmark framework, using the short
# Linear configs followed by the long ones.
op_bench.generate_pt_test(
    configs.linear_configs_short + configs.linear_configs_long,
    LinearBenchmark,
)


if __name__ == "__main__":
    # Only start the benchmark runner when this file is executed as a script.
    op_bench.benchmark_runner.main()
|