File: qactivation_test.py

import operator_benchmark as op_bench

import torch
import torch.ao.nn.quantized.functional as qF


r"""Microbenchmarks for the quantized activations."""

qactivation_long_configs = op_bench.cross_product_configs(
    dims=(
        # VGG-16 ReLUs with original shape: (-1, 3, 224, 224)
        (64, 224, 224),  # ReLU-1
        (128, 112, 112),  # ReLU-6
        (256, 56, 56),  # ReLU-11
        (512, 28, 28),  # ReLU-18
        (512, 14, 14),  # ReLU-25
        # Batch = 16
        (16, 64, 224, 224),  # ReLU-1
        (16, 128, 112, 112),  # ReLU-6
        (16, 256, 56, 56),  # ReLU-11
        (16, 512, 28, 28),  # ReLU-18
        (16, 512, 14, 14),  # ReLU-25
    ),
    contig=(False, True),
    inplace=(False, True),
    dtype=(torch.quint8,),
    tags=("long",),
)
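# cross_product_configs expands the Cartesian product of the attribute values
# above, so the "long" group covers 10 (dims) x 2 (contig) x 2 (inplace)
# x 1 (dtype) = 40 configurations per op.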

qactivation_short_configs = op_bench.cross_product_configs(
    dims=(
        (3, 4, 5),  # Rank=3
        (2, 3, 4, 5),  # Rank=4
        # Dimensions from the floating point benchmarks
        (512, 512),
        (256, 1024),
    ),
    contig=(False,),
    inplace=(False,),
    dtype=(torch.quint8, torch.qint8, torch.qint32),
    tags=("short",),
)

qactivation_ops = op_bench.op_list(
    attrs=(
        ("relu", torch.nn.ReLU()),
        ("relu6", torch.ops.quantized.relu6),
        ("functional.hardtanh", qF.hardtanh),
        ("functional.hardsigmoid", qF.hardsigmoid),
        ("functional.leaky_relu", qF.leaky_relu),
        ("functional.sigmoid", torch.nn.functional.sigmoid),
        ("functional.tanh", torch.nn.functional.tanh),
    ),
    attr_names=("op_name", "op_func"),
)


class QActivationBenchmarkBase(op_bench.TorchBenchmarkBase):
    r"""Base class for all the activations."""

    def _setup(self, dims, contig, dtype):
        # Random float input spanning roughly [-128, 128)
        f_input = (torch.rand(*dims) - 0.5) * 256
        self.scale = 1.0
        self.zero_point = 0

        # Quantize the tensor
        q_input = torch.quantize_per_tensor(
            f_input, scale=self.scale, zero_point=self.zero_point, dtype=dtype
        )
        if not contig:
            # Make non-contiguous
            new_shape = list(range(q_input.ndim))[::-1]
            q_input = q_input.permute(new_shape)

        self.inputs = {"q_input": q_input}

    def init(self, dims, contig, inplace, dtype, op_func):
        self._setup(dims, contig, dtype)
        self.qop = op_func


class QActivationBenchmark(QActivationBenchmarkBase):
    def forward(self, q_input):
        return self.qop(q_input)


op_bench.generate_pt_tests_from_op_list(
    qactivation_ops,
    qactivation_short_configs + qactivation_long_configs,
    QActivationBenchmark,
)


qactivation_scale_zero_point_ops = op_bench.op_list(
    attrs=(
        ("functional.hardswish", qF.hardswish),
        ("functional.elu", qF.elu),
        ("functional.celu", qF.celu),
    ),
    attr_names=("op_name", "op_func"),
)
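
# Unlike the ops benchmarked above, hardswish, elu, and celu in
# torch.ao.nn.quantized.functional take explicit output quantization
# parameters, so the benchmark below also forwards scale and zero_point
# along with the input.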


class QActivationScaleZeroPointBenchmark(QActivationBenchmarkBase):
    def forward(self, q_input):
        return self.qop(q_input, scale=self.scale, zero_point=self.zero_point)


op_bench.generate_pt_tests_from_op_list(
    qactivation_scale_zero_point_ops,
    qactivation_short_configs + qactivation_long_configs,
    QActivationScaleZeroPointBenchmark,
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
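
# Note: this file is normally executed through the operator_benchmark harness;
# the runner accepts command-line options (e.g. for selecting ops or filtering
# by the "short"/"long" tags defined above). See the operator_benchmark README
# in the PyTorch benchmarks tree for the exact flags of your version.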