File: adaround_fake_quantize.py

# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Tuple

import torch
from torch.ao.quantization.fake_quantize import _is_symmetric_quant
from torch.ao.quantization.utils import is_per_tensor
from torch.quantization import FakeQuantize
from torch.quantization.observer import MinMaxObserver


class AdaroundFakeQuantizer(FakeQuantize):
    """
    This is a FakeQuantizer that enables an adaptive rounding fake quantizer.
    Adaround is a technique to adaptively round weights, derived from the paper https://arxiv.org/pdf/2004.10568.pdf
    For HTP compatibility, we are targeting to use symmetric quantization
    """

    scale: torch.Tensor
    zero_point: torch.Tensor
    V: torch.nn.Parameter

    # pyre-fixme[3]: Return type must be annotated.
    def __init__(
        self,
        observer=MinMaxObserver,
        qscheme=torch.per_tensor_symmetric,  # not used, but needed for fakequant
        quant_min: int = -128,
        quant_max: int = 127,
        ch_axis: int = 0,
        # pyre-fixme[2]: Parameter must be annotated.
        **observer_kwargs,
    ):
        super().__init__(
            observer=observer,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            is_dynamic=False,
            **observer_kwargs,
        )
        # Sanity-check quant_min/quant_max; the parent FakeQuantize populates them into observer_kwargs.
        if quant_min is not None and quant_max is not None:
            assert (
                quant_min <= quant_max
            ), "quant_min must be less than or equal to quant_max"
        # pyre-fixme[4]: Attribute must be annotated.
        self.qscheme = qscheme
        self.is_per_tensor: bool = is_per_tensor(qscheme)
        self.is_symmetric: bool = _is_symmetric_quant(qscheme)
        assert self.is_symmetric, "Only symmetric quantization is supported"
        self.ch_axis: int = ch_axis

        self.scale = torch.tensor([], requires_grad=False)
        self.zero_point = torch.tensor([], requires_grad=False)
        self.V = torch.nn.Parameter(torch.tensor([]), requires_grad=True)
        # Fixed Stretch parameters
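        # zeta/gamma stretch the sigmoid output range to (-0.1, 1.1) so that, after
        # clamping to [0, 1], h(V) can reach exactly 0 or 1 (see the rectified
        # sigmoid in the AdaRound paper).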
        self.zeta: torch.Tensor = torch.tensor(1.1, requires_grad=False)
        self.gamma: torch.Tensor = torch.tensor(-0.1, requires_grad=False)
        self.sigmoid = torch.nn.Sigmoid()
        self.use_soft_rounding = True

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
            f"scale={self.scale}, zero_point={self.zero_point}, (self.V >= 0).int().sum()={(self.V >= 0).int().sum()}"
        )

    def enable_weight_fake_quant(self) -> None:
        self.fake_quant_enabled[0] = 1

    def get_rectified_sigmoid_func(self) -> torch.Tensor:
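        # Rectified sigmoid from the AdaRound paper:
        #   h(V) = clip(sigmoid(V) * (zeta - gamma) + gamma, 0, 1)
        # h(V) is the learned rounding offset added on top of floor(X / scale).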
        if self.use_soft_rounding:
            return torch.clamp(
                self.sigmoid(self.V) * (self.zeta - self.gamma) + self.gamma,
                min=0,
                max=1,
            )
        else:
            # Hard rounding: return the binary rounding decision (1 if V >= 0 else 0)
            return (self.V >= 0).int()

    @torch.jit.ignore
    def update_scale(
        self, X: torch.Tensor, _scale: torch.Tensor, _zero_point: torch.Tensor
    ) -> None:
        if self.scale.numel() == 0:
            self.scale.data = _scale.to(X.device)
            self.zero_point = _zero_point.to(X.device)
        else:
            self.scale.data = _scale
            # _zero_point is already reshaped for broadcasting by forward()
            if not self.is_symmetric:
                self.zero_point = _zero_point
            else:
                self.zero_point = torch.zeros_like(_zero_point)
        X_q = X / self.scale
        X_q_floor = torch.floor(X_q)
        residual = X_q - X_q_floor  # [0,1)
        assert torch.all(
            torch.ge(residual, 0)
        ), "residual should be non-negative [0, 1)"
        V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1)
        self.V.data = V_init

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        if self.observer_enabled[0] == 1:
            X_detached = X.detach()
            self.activation_post_process(X_detached)
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
                self.zero_point.device
            )
            if not self.is_per_tensor:
                for i in range(X.dim()):
                    if i == self.ch_axis:
                        continue
                    _scale = _scale.unsqueeze(i)
                    _zero_point = _zero_point.unsqueeze(i)
            self.update_scale(X_detached, _scale, _zero_point)

        if self.fake_quant_enabled[0] == 1:
            # Perform soft quantization
            # See the equation (23) in Adaround paper
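            # X_q_dq = scale * (clamp(floor(X / scale) + zero_point + h(V), quant_min, quant_max) - zero_point)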
            h_v = self.get_rectified_sigmoid_func()
            X_q = X / self.scale
            # Straight-Through Estimator for floor function
            X_q_floor = torch.floor(X_q) + self.zero_point
            # Regardless of rounding, gradient should be able to flow back to self.V from X_q_dq.
            # With adaround, we don't train weight, but train V only.
            X_q_dq = (
                torch.clamp(X_q_floor + h_v, min=self.quant_min, max=self.quant_max)
                - self.zero_point
            ) * self.scale
            return X_q_dq
        else:
            return X
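

# ---------------------------------------------------------------------------
# Illustrative usage (not part of the original module): a minimal sketch of how
# this fake quantizer might be driven, assuming a standalone weight tensor and a
# plain reconstruction loss as a stand-in for the layer-wise AdaRound objective.
# All names below (weight, optimizer settings, loop length) are illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    weight = torch.randn(16, 32)  # e.g. a Linear layer's weight

    # dtype=torch.qint8 is required for the default signed range [-128, 127].
    adaround_fq = AdaroundFakeQuantizer(dtype=torch.qint8)

    # First forward with the observer enabled computes scale/zero_point and
    # initializes V so that the soft rounding reproduces `weight` exactly.
    adaround_fq(weight)
    adaround_fq.disable_observer()

    # AdaRound keeps the weight frozen and optimizes only V.
    optimizer = torch.optim.Adam([adaround_fq.V], lr=1e-2)
    for _ in range(100):
        optimizer.zero_grad()
        w_soft = adaround_fq(weight)
        loss = torch.norm(w_soft - weight)  # stand-in for the true objective
        loss.backward()
        optimizer.step()

    # Switch to hard rounding to obtain the final rounding decision.
    adaround_fq.use_soft_rounding = False
    w_hard = adaround_fq(weight)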