# mypy: allow-untyped-defs
import abc
import collections
import itertools
import torch
from torch.nn.modules.module import _addindent
__all__ = [
    "WeightedQuantizedModule",
]


class WeightedQuantizedModule(torch.nn.Module, metaclass=abc.ABCMeta):
    """Wrapper for quantized modules that can be lowered from reference modules."""

    @classmethod
    @abc.abstractmethod
    def from_reference(cls, ref_module, output_scale, output_zero_point):
        raise NotImplementedError
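
    # A minimal sketch (not part of this module) of how a concrete lowered module
    # might implement `from_reference`. Names like `ExampleQuantizedLinear`,
    # `get_quantized_weight`, and `set_weight_bias` are illustrative assumptions,
    # not APIs defined in this file:
    #
    #     class ExampleQuantizedLinear(WeightedQuantizedModule):
    #         @classmethod
    #         def from_reference(cls, ref_module, output_scale, output_zero_point):
    #             qweight = ref_module.get_quantized_weight()
    #             qlinear = cls(ref_module.in_features, ref_module.out_features)
    #             qlinear.set_weight_bias(qweight, ref_module.bias)
    #             qlinear.scale = float(output_scale)
    #             qlinear.zero_point = int(output_zero_point)
    #             return qlinear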


def _get_weight_observer(observer):
    # FakeQuantize observer
    if hasattr(observer, "activation_post_process"):
        observer = observer.activation_post_process
    # UniformQuantizationObserverBase observer
    return observer


def _needs_weight_clamping(observer, dtype):
    observer = _get_weight_observer(observer)
    if dtype in [torch.qint8, torch.quint8, torch.qint32]:
        info = torch.iinfo(dtype)
        return observer.quant_min > info.min or observer.quant_max < info.max
    return False


def _clamp_weights(qweight, observer, scale, zp):
    if not _needs_weight_clamping(observer, qweight.dtype):
        return qweight

    observer = _get_weight_observer(observer)
    min_, max_ = observer.quant_min, observer.quant_max

    # Doing this because we can't use torch.ops.quantized.clamp() with per_channel qscheme yet.
    qw_int_max = torch.clone(qweight.int_repr()).fill_(max_)
    qw_int_min = torch.clone(qweight.int_repr()).fill_(min_)
    qw_int = torch.minimum(torch.maximum(qweight.int_repr(), qw_int_min), qw_int_max)

    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch._make_per_tensor_quantized_tensor(
            qw_int, scale.item(), zp.item()
        )
    elif observer.qscheme in [
        torch.per_channel_symmetric,
        torch.per_channel_affine,
        torch.per_channel_affine_float_qparams,
    ]:
        qweight = torch._make_per_channel_quantized_tensor(
            qw_int, scale, zp, axis=observer.ch_axis
        )
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight
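
# Example of when clamping applies (an illustrative sketch, not executed here):
# an observer restricted to a narrower integer range than its dtype allows, e.g.
# torch.qint8 limited to [-127, 127] while torch.iinfo(torch.qint8).min is -128,
# makes _needs_weight_clamping return True, and _clamp_weights then clips the
# integer representation before re-wrapping it as a quantized tensor:
#
#     from torch.ao.quantization.observer import MinMaxObserver
#     obs = MinMaxObserver(
#         dtype=torch.qint8,
#         qscheme=torch.per_tensor_symmetric,
#         quant_min=-127,
#         quant_max=127,
#     )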


def _quantize_weight(float_wt, observer):
    wt_scale, wt_zp = observer.calculate_qparams()
    if observer.qscheme in [torch.per_tensor_symmetric, torch.per_tensor_affine]:
        qweight = torch.quantize_per_tensor(
            float_wt, float(wt_scale), int(wt_zp), torch.qint8
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    elif observer.qscheme in [torch.per_channel_symmetric, torch.per_channel_affine]:
        wt_axis = observer.ch_axis
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.double),
            wt_zp.to(torch.int64),
            wt_axis,
            torch.qint8,
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    elif observer.qscheme in [torch.per_channel_affine_float_qparams]:
        qweight = torch.quantize_per_channel(
            float_wt,
            wt_scale.to(torch.float),
            wt_zp.to(torch.float),
            observer.ch_axis,
            observer.dtype,
        )
        qweight = _clamp_weights(qweight, observer, wt_scale, wt_zp)
    else:
        raise ValueError(f"Unexpected qscheme {observer.qscheme}")
    return qweight
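
# Usage sketch for _quantize_weight (assumes the standard observer API in
# torch.ao.quantization.observer; shown as a comment so nothing runs on import):
#
#     from torch.ao.quantization.observer import PerChannelMinMaxObserver
#     observer = PerChannelMinMaxObserver(
#         dtype=torch.qint8, qscheme=torch.per_channel_symmetric
#     )
#     float_wt = torch.randn(8, 4)
#     observer(float_wt)  # record per-channel min/max statistics
#     qweight = _quantize_weight(float_wt, observer)  # per-channel torch.qint8 tensor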


def _ntuple_from_first(n):
    """Converts the argument to a tuple of size n
    with the first element repeated."""

    def parse(x):
        while isinstance(x, collections.abc.Sequence):
            if len(x) == n:
                break
            x = x[0]
        return tuple(itertools.repeat(x, n))

    return parse


def _hide_packed_params_repr(self, params):
    # We don't want to show `PackedParams` children, hence the custom
    # `__repr__`. This is the same as nn.Module.__repr__, except for the
    # check against the `params` class.
    extra_lines = []
    extra_repr = self.extra_repr()
    # empty string will be split into list ['']
    if extra_repr:
        extra_lines = extra_repr.split("\n")
    child_lines = []
    for key, module in self._modules.items():
        if isinstance(module, params):
            continue
        mod_str = repr(module)
        mod_str = _addindent(mod_str, 2)
        child_lines.append("(" + key + "): " + mod_str)
    lines = extra_lines + child_lines

    main_str = self._get_name() + "("
    if lines:
        # simple one-liner info, which most builtin Modules will use
        if len(extra_lines) == 1 and not child_lines:
            main_str += extra_lines[0]
        else:
            main_str += "\n  " + "\n  ".join(lines) + "\n"

    main_str += ")"
    return main_str
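
# Typical usage (sketch): a quantized module binds this helper as its __repr__ so
# that its packed-parameter child module is skipped in the printed representation.
# `ExampleQuantizedLinear` and `LinearPackedParams` are stand-in names here; the
# real callers live in the quantized module implementations.
#
#     class ExampleQuantizedLinear(torch.nn.Module):
#         def __repr__(self):
#             return _hide_packed_params_repr(self, LinearPackedParams)

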
_pair_from_first = _ntuple_from_first(2)
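
# For example (sketch of the intended behavior of the helper above):
#
#     _pair_from_first(3)    # -> (3, 3)
#     _pair_from_first([3])  # -> (3, 3)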