File: qarithmetic_test.py

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (86 lines) | stat: -rw-r--r-- 2,875 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
from torch._ops import ops
import operator_benchmark as op_bench

# Benchmark configurations shared by every quantized arithmetic benchmark in
# this file. Each keyword lists the candidate values for one attribute;
# judging by the helper's name the result is the cross product of them
# (confirm against operator_benchmark).
qarithmetic_binary_configs = op_bench.cross_product_configs(
    N=(2, 8, 64, 512),  # the benchmark input is built as an N x N tensor
    dtype=(torch.quint8, torch.qint8, torch.qint32),  # quantized dtypes
    contig=(False, True),  # False -> the quantized input is permuted
    tags=('short',)
)


# Tensor-tensor quantized ops to benchmark. Each entry in ``attrs`` is a
# (name, callable) pair labelled by ``attr_names``; the ``op_func`` value is
# what QFunctionalBenchmark.init receives as its ``op_func`` argument.
qarithmetic_binary_ops = op_bench.op_list(
    attrs=(
        ('add', ops.quantized.add),
        ('add_relu', ops.quantized.add_relu),
        ('mul', ops.quantized.mul),
    ),
    attr_names=('op_name', 'op_func'),
)

# Tensor-scalar quantized ops to benchmark, in the same (name, callable)
# layout as the tensor-tensor list; consumed by QFunctionalScalarBenchmark.
qarithmetic_binary_scalar_ops = op_bench.op_list(
    attrs=(
        ('add_scalar', ops.quantized.add_scalar),
        ('mul_scalar', ops.quantized.mul_scalar),
    ),
    attr_names=('op_name', 'op_func'),
)

class _QFunctionalBinaryArithmeticBenchmarkBase(op_bench.TorchBenchmarkBase):
    """Common input preparation for the quantized arithmetic benchmarks.

    Subclasses call :meth:`setup` from their ``init`` to obtain
    ``self.q_input_a``, ``self.scale`` and ``self.zero_point``.
    """

    def setup(self, N, dtype, contig):
        """Create the quantized ``N x N`` operand used by the benchmark.

        Args:
            N: side length of the square input tensor.
            dtype: quantized dtype passed to ``torch.quantize_per_tensor``.
            contig: when false, the quantized tensor is permuted with its
                dimensions reversed instead of being stored as created.
        """
        self.qfunctional = torch.ao.nn.quantized.QFunctional()

        # TODO: Consider more diverse shapes
        # Random floats scaled to the range [-128, 128).
        float_values = (torch.rand(N, N) - 0.5) * 256

        self.scale = 1.0
        self.zero_point = 0
        quantized = torch.quantize_per_tensor(
            float_values,
            scale=self.scale,
            zero_point=self.zero_point,
            dtype=dtype,
        )

        if contig:
            self.q_input_a = quantized
        else:
            # Reverse the dimension order to get a permuted view.
            reversed_dims = list(reversed(range(float_values.ndim)))
            self.q_input_a = quantized.permute(reversed_dims)


class QFunctionalBenchmark(_QFunctionalBinaryArithmeticBenchmarkBase):
    """Benchmark for quantized ops that take two quantized tensors."""

    def init(self, N, dtype, contig, op_func):
        """Prepare the inputs and remember which op to run.

        Args:
            N, dtype, contig: forwarded to the base class ``setup``.
            op_func: callable invoked by :meth:`forward`.
        """
        super().setup(N, dtype, contig)

        # The same quantized tensor is used for both operands.
        operand = self.q_input_a
        self.inputs = {
            "q_input_a": operand,
            "q_input_b": operand,
            "scale": self.scale,
            "zero_point": self.zero_point,
        }
        self.op_func = op_func

    def forward(self, q_input_a, q_input_b, scale: float, zero_point: int):
        """Apply ``op_func`` to both operands with the given output qparams."""
        result = self.op_func(
            q_input_a,
            q_input_b,
            scale=scale,
            zero_point=zero_point,
        )
        return result


# Register QFunctionalBenchmark for the tensor-tensor ops (add, add_relu,
# mul) over the shared configuration set.
op_bench.generate_pt_tests_from_op_list(qarithmetic_binary_ops,
                                        qarithmetic_binary_configs,
                                        QFunctionalBenchmark)


class QFunctionalScalarBenchmark(_QFunctionalBinaryArithmeticBenchmarkBase):
    """Benchmark for quantized ops that take a quantized tensor and a scalar."""

    def init(self, N, dtype, contig, op_func):
        """Prepare the inputs and remember which op to run.

        Args:
            N, dtype, contig: forwarded to the base class ``setup``.
            op_func: callable invoked by :meth:`forward`.
        """
        super().setup(N, dtype, contig)

        # The scalar operand is a fixed integer constant.
        tensor_operand = self.q_input_a
        self.inputs = {
            "q_input": tensor_operand,
            "scalar_input": 42,
        }
        self.op_func = op_func

    def forward(self, q_input, scalar_input: int):
        """Apply ``op_func`` to the quantized tensor and the scalar."""
        result = self.op_func(q_input, scalar_input)
        return result


# Register QFunctionalScalarBenchmark for the tensor-scalar ops (add_scalar,
# mul_scalar) over the same configuration set as the tensor-tensor ops.
op_bench.generate_pt_tests_from_op_list(qarithmetic_binary_scalar_ops,
                                        qarithmetic_binary_configs,
                                        QFunctionalScalarBenchmark)


# Hand control to the operator_benchmark runner only when this file is
# executed as a script, not when it is imported.
if __name__ == '__main__':
    op_bench.benchmark_runner.main()