# mypy: allow-untyped-defs
from typing import List
import torch
from torch.nn.parameter import Parameter
__all__: List[str] = []
class _LearnableFakeQuantize(torch.ao.quantization.FakeQuantizeBase):
r"""Generalized extension of the FakeQuantize module in fake_quantize.py.
This is an extension of the FakeQuantize module in fake_quantize.py, which
supports more generalized lower-bit quantization and supports learning of the scale
and zero point parameters through backpropagation.
In addition to the attributes in the original FakeQuantize module, the _LearnableFakeQuantize
module also includes the following attributes to support quantization parameter learning.
* :attr:`channel_len` defines the length of the channel when initializing scale and zero point
for the per channel case.
    * :attr:`use_grad_scaling` defines whether the gradients for scale and zero point are
      normalized by a constant that is proportional to the square root of the number of
      elements in the tensor. The literature justifying the use of this particular constant
      can be found at https://openreview.net/pdf?id=rkgO66VKDS.
* :attr:`fake_quant_enabled` defines the flag for enabling fake quantization on the output.
* :attr:`static_enabled` defines the flag for using observer's static estimation for
scale and zero point.
* :attr:`learning_enabled` defines the flag for enabling backpropagation for scale and zero point.
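
    Example (an illustrative sketch, not an official recipe; the observer class,
    quantization range, and tensor shape below are assumptions)::

        >>> import torch
        >>> from torch.ao.quantization.observer import MovingAverageMinMaxObserver
        >>> fq = _LearnableFakeQuantize(MovingAverageMinMaxObserver, quant_min=0, quant_max=255)
        >>> x = torch.randn(4, 8)
        >>> fq.enable_static_estimate()     # observer seeds scale / zero_point
        >>> _ = fq(x)
        >>> _ = fq.enable_param_learning()  # scale / zero_point now learn via backprop
        >>> out = fq(x)
        >>> out.sum().backward()            # gradients reach fq.scale and fq.zero_point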
"""
def __init__(
self,
observer,
quant_min=0,
quant_max=255,
scale=1.0,
zero_point=0.0,
channel_len=-1,
use_grad_scaling=False,
**observer_kwargs,
):
super().__init__()
assert quant_min < quant_max, "quant_min must be strictly less than quant_max."
self.quant_min = quant_min
self.quant_max = quant_max
# also pass quant_min and quant_max to observer
observer_kwargs["quant_min"] = quant_min
observer_kwargs["quant_max"] = quant_max
self.use_grad_scaling = use_grad_scaling
if channel_len == -1:
self.scale = Parameter(torch.tensor([scale]))
self.zero_point = Parameter(torch.tensor([zero_point]))
else:
assert (
isinstance(channel_len, int) and channel_len > 0
), "Channel size must be a positive integer."
self.scale = Parameter(torch.tensor([scale] * channel_len))
self.zero_point = Parameter(torch.tensor([zero_point] * channel_len))
self.activation_post_process = observer(**observer_kwargs)
assert (
torch.iinfo(self.activation_post_process.dtype).min <= quant_min
), "quant_min out of bound"
assert (
quant_max <= torch.iinfo(self.activation_post_process.dtype).max
), "quant_max out of bound"
self.dtype = self.activation_post_process.dtype
self.qscheme = self.activation_post_process.qscheme
self.ch_axis = (
self.activation_post_process.ch_axis
if hasattr(self.activation_post_process, "ch_axis")
else -1
)
self.register_buffer("fake_quant_enabled", torch.tensor([1], dtype=torch.uint8))
self.register_buffer("static_enabled", torch.tensor([1], dtype=torch.uint8))
self.register_buffer("learning_enabled", torch.tensor([0], dtype=torch.uint8))
bitrange = torch.tensor(quant_max - quant_min + 1).double()
self.bitwidth = int(torch.log2(bitrange).item())
self.register_buffer("eps", torch.tensor([torch.finfo(torch.float32).eps]))
@torch.jit.export
def enable_param_learning(self):
r"""Enable parameter learning over static observer estimates.
Enables learning of quantization parameters and
disables static observer estimates. Forward path returns fake quantized X.
"""
self.toggle_qparam_learning(enabled=True).toggle_fake_quant(
enabled=True
).toggle_observer_update(enabled=False)
return self
@torch.jit.export
def enable_static_estimate(self):
"""Enable static estimates of quantization parameters.
Enables static observer estimates and disables learning of
quantization parameters. Forward path returns fake quantized X.
"""
self.toggle_qparam_learning(enabled=False).toggle_fake_quant(
enabled=True
).toggle_observer_update(enabled=True)
@torch.jit.export
def enable_static_observation(self):
"""Enable accumulation of data without updating quantization parameters.
Enables static observer accumulating data from input but doesn't
update the quantization parameters. Forward path returns the original X.
"""
self.toggle_qparam_learning(enabled=False).toggle_fake_quant(
enabled=False
).toggle_observer_update(enabled=True)
@torch.jit.export
def toggle_observer_update(self, enabled=True):
self.static_enabled[0] = int(enabled) # type: ignore[operator]
return self
@torch.jit.export
def enable_observer(self, enabled=True):
self.toggle_observer_update(enabled)
@torch.jit.export
def toggle_qparam_learning(self, enabled=True):
self.learning_enabled[0] = int(enabled) # type: ignore[operator]
self.scale.requires_grad = enabled
self.zero_point.requires_grad = enabled
return self
@torch.jit.export
def toggle_fake_quant(self, enabled=True):
self.fake_quant_enabled[0] = int(enabled)
return self
@torch.jit.export
def observe_quant_params(self):
print(f"_LearnableFakeQuantize Scale: {self.scale.detach()}")
print(f"_LearnableFakeQuantize Zero Point: {self.zero_point.detach()}")
@torch.jit.export
def calculate_qparams(self):
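        # Keep scale strictly positive (clamped to float32 eps) and report an
        # integer zero point restricted to the valid quantization range.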
self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator]
scale = self.scale.detach()
zero_point = (
self.zero_point.detach()
.round()
.clamp(self.quant_min, self.quant_max)
.long()
)
return scale, zero_point
def forward(self, X):
if self.static_enabled[0] == 1: # type: ignore[index]
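            # Static estimation: let the observer track statistics of X and copy
            # its computed qparams into the learnable scale / zero_point.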
self.activation_post_process(X.detach())
_scale, _zero_point = self.activation_post_process.calculate_qparams()
_scale = _scale.to(self.scale.device)
_zero_point = _zero_point.to(self.zero_point.device)
self.scale.data.copy_(_scale)
self.zero_point.data.copy_(_zero_point)
else:
self.scale.data.clamp_(min=self.eps.item()) # type: ignore[operator]
if self.fake_quant_enabled[0] == 1:
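            # Symmetric quantization schemes pin the zero point at 0.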
if self.qscheme in (
torch.per_channel_symmetric,
torch.per_tensor_symmetric,
):
self.zero_point.data.zero_()
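            # LSQ-style gradient scaling: scale the qparam gradients by
            # 1 / sqrt(numel * quant_max) so their magnitude stays comparable
            # across tensor sizes (see https://openreview.net/pdf?id=rkgO66VKDS).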
if self.use_grad_scaling:
grad_factor = 1.0 / (X.numel() * self.quant_max) ** 0.5
else:
grad_factor = 1.0
if self.qscheme in (torch.per_channel_symmetric, torch.per_channel_affine):
X = torch._fake_quantize_learnable_per_channel_affine(
X,
self.scale,
self.zero_point,
self.ch_axis,
self.quant_min,
self.quant_max,
grad_factor,
)
else:
X = torch._fake_quantize_learnable_per_tensor_affine(
X,
self.scale,
self.zero_point,
self.quant_min,
self.quant_max,
grad_factor,
)
return X
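

# A minimal construction sketch for the per-channel case (illustrative only;
# the observer class, dtype, ranges, and `out_channels` below are assumptions,
# not part of this module):
#
#     from torch.ao.quantization.observer import MovingAveragePerChannelMinMaxObserver
#
#     weight_fake_quant_ctor = _LearnableFakeQuantize.with_args(
#         observer=MovingAveragePerChannelMinMaxObserver,
#         quant_min=-128,
#         quant_max=127,
#         dtype=torch.qint8,
#         qscheme=torch.per_channel_symmetric,
#         channel_len=out_channels,       # hypothetical channel dimension size
#         use_grad_scaling=True,
#     )
#     weight_fake_quant = weight_fake_quant_ctor()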