
import operator_benchmark as op_bench
import torch


"""Microbenchmarks for point-wise unary operator."""


# Configs for pointwise unary ops
unary_ops_configs_short = op_bench.config_list(
    attr_names=['M', 'N'],
    attrs=[
        [512, 512],
    ],
    cross_product_configs={
        'device': ['cpu', 'cuda'],
    },
    tags=['short']
)

unary_ops_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    device=['cpu', 'cuda'],
    tags=['long']
)

class UnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    def init(self, M, N, device, op_func):
        self.inputs = {
            "input": torch.rand(M, N, device=device)
        }
        self.op_func = op_func

    def forward(self, input):
        return self.op_func(input)

def bernoulli_(input):
    return input.bernoulli_()

def cauchy_(input):
    return input.cauchy_()

def digamma_(input):
    return input.digamma_()

def exponential_(input):
    return input.exponential_()

def normal_(input):
    return input.normal_()

def random_(input):
    return input.random_()

def sign_(input):
    return input.sign_()

def uniform_(input):
    return input.uniform_()

def half_(input):
    return input.half()

def long_(input):
    return input.long()

unary_ops_list = op_bench.op_list(
    attr_names=['op_name', 'op_func'],
    attrs=[
        ['abs', torch.abs],
        ['abs_', torch.abs_],
        ['acos', torch.acos],
        ['acos_', torch.acos_],
        ['argsort', torch.argsort],
        ['asin', torch.asin],
        ['asin_', torch.asin_],
        ['atan', torch.atan],
        ['atan_', torch.atan_],
        ['ceil', torch.ceil],
        ['ceil_', torch.ceil_],
        ['clone', torch.clone],
        ['cos', torch.cos],
        ['cos_', torch.cos_],
        ['cosh', torch.cosh],
        ['digamma', torch.digamma],
        ['erf', torch.erf],
        ['erf_', torch.erf_],
        ['erfc', torch.erfc],
        ['erfc_', torch.erfc_],
        ['erfinv', torch.erfinv],
        ['exp', torch.exp],
        ['exp_', torch.exp_],
        ['expm1', torch.expm1],
        ['expm1_', torch.expm1_],
        ['floor', torch.floor],
        ['floor_', torch.floor_],
        ['frac', torch.frac],
        ['frac_', torch.frac_],
        ['hardshrink', torch.hardshrink],
        ['lgamma', torch.lgamma],
        ['log', torch.log],
        ['log10', torch.log10],
        ['log10_', torch.log10_],
        ['log1p', torch.log1p],
        ['log1p_', torch.log1p_],
        ['log2', torch.log2],
        ['log2_', torch.log2_],
        ['log_', torch.log_],
        ['logit', torch.logit],
        ['logit_', torch.logit_],
        ['neg', torch.neg],
        ['neg_', torch.neg_],
        ['reciprocal', torch.reciprocal],
        ['reciprocal_', torch.reciprocal_],
        ['relu', torch.relu],
        ['relu_', torch.relu_],
        ['round', torch.round],
        ['round_', torch.round_],
        ['rsqrt', torch.rsqrt],
        ['rsqrt_', torch.rsqrt_],
        ['sigmoid', torch.sigmoid],
        ['sigmoid_', torch.sigmoid_],
        ['sign', torch.sign],
        ['sgn', torch.sgn],
        ['sin', torch.sin],
        ['sin_', torch.sin_],
        ['sinh', torch.sinh],
        ['sqrt', torch.sqrt],
        ['sqrt_', torch.sqrt_],
        ['square', torch.square],
        ['square_', torch.square_],
        ['tan', torch.tan],
        ['tan_', torch.tan_],
        ['tanh', torch.tanh],
        ['tanh_', torch.tanh_],
        ['trunc', torch.trunc],
        ['trunc_', torch.trunc_],
        ['unique', torch.functional._return_output],
        ['zero_', torch.zero_],
        ['bernoulli_', bernoulli_],
        ['cauchy_', cauchy_],
        ['digamma_', digamma_],
        ['exponential_', exponential_],
        ['normal_', normal_],
        ['random_', random_],
        ['sign_', sign_],
        ['uniform_', uniform_],
        ['half', half_],
        ['long', long_],
    ],
)


op_bench.generate_pt_tests_from_op_list(unary_ops_list,
                                        unary_ops_configs_short + unary_ops_configs_long,
                                        UnaryOpBenchmark)


if __name__ == "__main__":
    op_bench.benchmark_runner.main()
