1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
|
# mypy: allow-untyped-defs
import torch
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.quantized as nnq
from torch.ao.nn.quantized.modules.utils import _quantize_weight
# Names exported by ``from <module> import *``: only the dynamic quantized Linear.
__all__ = ["Linear"]
class Linear(nnq.Linear):
    r"""
    A dynamic quantized linear module with floating point tensor as inputs and outputs.
    We adopt the same interface as `torch.nn.Linear`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

    Attributes:
        weight (Tensor): the non-learnable quantized weights of the module which are of
                         shape :math:`(\text{out\_features}, \text{in\_features})`.
        bias (Tensor): the non-learnable floating point bias of the module of shape
                       :math:`(\text{out\_features})`. If :attr:`bias` is ``True``,
                       the values are initialized to zero.

    Examples::

        >>> # xdoctest: +SKIP
        >>> m = nn.quantized.dynamic.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """

    # version used in this class is different from the parent class nnq.Linear
    _version = 4

    def __init__(self, in_features, out_features, bias_=True, dtype=torch.qint8):
        """Build the module; weight/bias packing is delegated to ``nnq.Linear``.

        Args:
            in_features: size of each input sample.
            out_features: size of each output sample.
            bias_: whether the parent module is created with a bias.
            dtype: packed weight dtype; ``forward`` supports ``torch.qint8``
                and ``torch.float16``.
        """
        super().__init__(in_features, out_features, bias_, dtype=dtype)
        # We don't muck around with buffers or attributes or anything here
        # to keep the module simple. *everything* is simply a Python attribute.
        # Serialization logic is explicitly handled in the below serialization and
        # deserialization modules
        self.version = 4

    def forward(self, x):
        """Apply the dynamic quantized linear op to floating point input ``x``.

        Returns:
            The op result cast back to ``x.dtype``.

        Raises:
            RuntimeError: if the packed params dtype is neither qint8 nor float16.
        """
        # Note that we can handle self.bias == None case.
        packed_dtype = self._packed_params.dtype
        if packed_dtype == torch.qint8:
            if self.version is None or self.version < 4:
                # Modules restored from state dicts older than version 4 (or with
                # no recorded version) keep calling the op without reduce_range.
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params
                )
            else:
                Y = torch.ops.quantized.linear_dynamic(
                    x, self._packed_params._packed_params, reduce_range=True
                )
        elif packed_dtype == torch.float16:
            Y = torch.ops.quantized.linear_dynamic_fp16(
                x, self._packed_params._packed_params
            )
        else:
            raise RuntimeError("Unsupported dtype on dynamic quantized linear!")
        return Y.to(x.dtype)

    def _get_name(self):
        """Return the display name used in this module's repr."""
        return "DynamicQuantizedLinear"

    def extra_repr(self):
        """Describe shape and dtype; qint8 modules also report the weight qscheme."""
        extra_repr_str = f"in_features={self.in_features}, out_features={self.out_features}, dtype={self._packed_params.dtype}"
        if self._packed_params.dtype == torch.qint8:
            extra_repr_str += f", qscheme={self.weight().qscheme()}"
        return extra_repr_str

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Record the serialized version, then defer loading to ``nnq.Linear``.

        The stored version (``None`` when absent) selects the op variant used
        by ``forward``.
        """
        version = local_metadata.get("version", None)
        self.version = version
        # NOTE(review): strict is forced to False and the caller's ``strict``
        # argument is ignored — presumably deliberate; confirm before changing.
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a dynamic quantized module from a float module or qparams_dict

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
                          utilities or provided by the user
            use_precomputed_fake_quant (bool): not read by this implementation
                (presumably kept for signature parity with other ``from_float``
                methods).

        Raises:
            AssertionError: if ``mod`` is not one of the supported float module
                types, has no ``qconfig`` attribute, or its weight observer dtype
                is neither qint8 nor float16.
        """
        float_modules = [
            torch.nn.Linear,
            torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
            torch.ao.nn.intrinsic.modules.fused.LinearReLU,
            torch.ao.nn.qat.dynamic.Linear,
        ]
        # Exact type match on purpose: subclasses are not accepted.
        assert type(mod) in float_modules, (
            "nn.quantized.dynamic.Linear.from_float only works for one of "
            + str([float_mod.__name__ for float_mod in float_modules])
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        if type(mod) is nni.LinearReLU:
            # Quantize the wrapped Linear; the fused ReLU has no weights.
            mod = mod[0]
        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer = mod.qconfig.weight()
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer = default_dynamic_qconfig.weight()
        dtype = weight_observer.dtype
        assert dtype in [torch.qint8, torch.float16], (
            "The only supported dtypes for "
            f"dynamic quantized linear are qint8 and float16 got: {dtype}"
        )
        weight_observer(mod.weight)
        if dtype == torch.qint8:
            qweight = _quantize_weight(mod.weight.float(), weight_observer)
        elif dtype == torch.float16:
            qweight = mod.weight.float()
        else:
            # Only reachable when asserts are stripped (python -O).
            raise RuntimeError(
                "Unsupported dtype specified for dynamic quantized Linear!"
            )
        qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
        qlinear.set_weight_bias(qweight, mod.bias)
        return qlinear

    @classmethod
    def from_reference(cls, ref_qlinear):
        """Create a (fbgemm/qnnpack) dynamic quantized module from a reference quantized
        module

        Args:
            ref_qlinear (Module): a reference quantized module, either produced by
                                  torch.ao.quantization functions or provided by the user

        Returns:
            A new instance built with ``ref_qlinear.weight_dtype``, holding its
            quantized weight and float bias.
        """
        qlinear = cls(
            ref_qlinear.in_features,
            ref_qlinear.out_features,
            dtype=ref_qlinear.weight_dtype,
        )
        qweight = ref_qlinear.get_quantized_weight()
        bias = ref_qlinear.bias
        qlinear.set_weight_bias(qweight, bias)
        return qlinear
|