from collections.abc import Iterable
import torch
import torch.nn as nn
import torch.nn.intrinsic as nni
import torch.nn.intrinsic.qat as nniqat
from torch.nn.utils.fusion import fuse_linear_bn_weights
from torch.nn.utils.parametrize import type_before_parametrizations
from typing import Optional
from .utils import _quantize_weight, hide_packed_params_repr, WeightedQuantizedModule
__all__ = ['LinearPackedParams', 'Linear']
class LinearPackedParams(torch.nn.Module):
_version = 3
def __init__(self, dtype=torch.qint8):
super().__init__()
self.dtype = dtype
if self.dtype == torch.qint8:
wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
        elif self.dtype == torch.float16:
            wq = torch.zeros([1, 1], dtype=torch.float)
        else:
            raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
self.set_weight_bias(wq, None)
@torch.jit.export
def set_weight_bias(self, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> None:
if self.dtype == torch.qint8:
self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)
elif self.dtype == torch.float16:
self._packed_params = torch.ops.quantized.linear_prepack_fp16(weight, bias)
else:
raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
@torch.jit.export
def _weight_bias(self):
if self.dtype == torch.qint8:
return torch.ops.quantized.linear_unpack(self._packed_params)
elif self.dtype == torch.float16:
return torch.ops.quantized.linear_unpack_fp16(self._packed_params)
else:
raise RuntimeError('Unsupported dtype on dynamic quantized linear!')
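    # forward is intentionally an identity: this module exists only to own and
    # (de)serialize the packed (weight, bias) blob; the enclosing quantized
    # module's forward consumes self._packed_params directly.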
def forward(self, x):
return x
# Version 1
# self
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 2
# self
# |--- weight : Tensor
# |--- bias : Tensor
# |--- dtype : torch.dtype
#
# Version 3
# self
# |--- _packed_params : (Tensor, Tensor) representing (weight, bias)
# of LinearPackedParams
# |--- dtype : torch.dtype
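    # The packed format is backend-specific and must not outlive the process that
    # created it, so serialization below stores the unpacked form plus the dtype.
    # Concretely, one LinearPackedParams contributes two state_dict entries
    # (prefix omitted):
    #   'dtype'          -> self.dtype
    #   '_packed_params' -> (weight, bias) as returned by self._weight_bias()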
def _save_to_state_dict(self, destination, prefix, keep_vars):
super(LinearPackedParams, self)._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'dtype'] = self.dtype
destination[prefix + '_packed_params'] = self._weight_bias()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
version = local_metadata.get('version', None)
if version is None or version < 2:
self.dtype = torch.qint8
else:
self.dtype = state_dict[prefix + 'dtype']
state_dict.pop(prefix + 'dtype')
if version is None or version < 3:
self.set_weight_bias(state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
state_dict.pop(prefix + 'weight')
state_dict.pop(prefix + 'bias')
if version == 3:
weight, bias = state_dict[prefix + '_packed_params']
state_dict.pop(prefix + '_packed_params')
self.set_weight_bias(weight, bias)
super(LinearPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
def __repr__(self):
return self._weight_bias().__repr__()
class Linear(WeightedQuantizedModule):
r"""
A quantized linear module with quantized tensor as inputs and outputs.
    We adopt the same interface as `torch.nn.Linear`; please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`~torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later.

Attributes:
weight (Tensor): the non-learnable quantized weights of the module of
shape :math:`(\text{out\_features}, \text{in\_features})`.
bias (Tensor): the non-learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized to zero.
scale: `scale` parameter of output Quantized Tensor, type: double
zero_point: `zero_point` parameter for output Quantized Tensor, type: long
Examples::
>>> m = nn.quantized.Linear(20, 30)
>>> input = torch.randn(128, 20)
>>> # xdoctest: +SKIP
>>> input = torch.quantize_per_tensor(input, 1.0, 0, torch.quint8)
>>> output = m(input)
>>> print(output.size())
torch.Size([128, 30])
"""
_version = 3
_FLOAT_MODULE = (nn.Linear, nn.modules.linear.NonDynamicallyQuantizableLinear)
def __init__(self, in_features, out_features, bias_=True,
dtype=torch.qint8):
super().__init__()
# We don't muck around with buffers or attributes or anything here
# to keep the module simple. *everything* is simply a Python attribute.
# Serialization logic is explicitly handled in the below serialization and
# deserialization modules
self.in_features = in_features
self.out_features = out_features
bias = None
if bias_:
bias = torch.zeros(out_features, dtype=torch.float)
if dtype == torch.qint8:
qweight = torch._empty_affine_quantized(
[out_features, in_features], scale=1, zero_point=0, dtype=torch.qint8)
elif dtype == torch.float16:
qweight = torch.zeros([out_features, in_features], dtype=torch.float)
else:
raise RuntimeError('Unsupported dtype specified for quantized Linear!')
self._packed_params = LinearPackedParams(dtype)
self._packed_params.set_weight_bias(qweight, bias)
self.scale = 1.0
self.zero_point = 0
def _get_name(self):
return 'QuantizedLinear'
def extra_repr(self):
return 'in_features={}, out_features={}, scale={}, zero_point={}, qscheme={}'.format(
self.in_features, self.out_features, self.scale, self.zero_point, self.weight().qscheme()
)
def __repr__(self):
return hide_packed_params_repr(self, LinearPackedParams)
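    # The input is expected to be a per-tensor-quantized quint8 tensor (see the
    # class docstring example); the quantized linear op emits an output tensor
    # quantized with self.scale and self.zero_point.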
def forward(self, x: torch.Tensor) -> torch.Tensor:
return torch.ops.quantized.linear(
x, self._packed_params._packed_params, self.scale, self.zero_point)
# ===== Serialization methods =====
# The special consideration here is that we have to unpack the weights into their
# regular QTensor form for serialization. Packed weights should not live
    # outside the process in which they were created; rather, they should be
    # derived from the QTensor weight.
#
# Version 1
# self
# |--- scale : float
# |--- zero_point : int
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 2
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- weight : Tensor
# |--- bias : Tensor
#
# Version 3
# self
# |--- scale : float
# |--- zero_point : int
# |--- _packed_params : Module
# |--- _packed_params : (Tensor, Tensor) representing weight, bias
# of LinearPackedParams C++ struct
#
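    # Concretely, a version-3 checkpoint of this module contains, per instance:
    #   '<prefix>scale'                         -> tensor(self.scale)
    #   '<prefix>zero_point'                    -> tensor(self.zero_point)
    #   '<prefix>_packed_params.dtype'          -> self._packed_params.dtype
    #   '<prefix>_packed_params._packed_params' -> (weight, bias) unpacked pair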
def _save_to_state_dict(self, destination, prefix, keep_vars):
super()._save_to_state_dict(destination, prefix, keep_vars)
destination[prefix + 'scale'] = torch.tensor(self.scale)
destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)
# ===== Deserialization methods =====
# Counterpart to the serialization methods, we must pack the serialized QTensor
# weight into its packed format for use by the FBGEMM ops.
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
self.scale = float(state_dict[prefix + 'scale'])
state_dict.pop(prefix + 'scale')
self.zero_point = int(state_dict[prefix + 'zero_point'])
state_dict.pop(prefix + 'zero_point')
version = local_metadata.get('version', None)
if version is None or version == 1:
# We moved the parameters into a LinearPackedParameters submodule
weight = state_dict.pop(prefix + 'weight')
bias = state_dict.pop(prefix + 'bias')
state_dict.update({prefix + '_packed_params.weight': weight,
prefix + '_packed_params.bias': bias})
super()._load_from_state_dict(
state_dict, prefix, local_metadata, False,
missing_keys, unexpected_keys, error_msgs)
# Function rather than property to make sure that JIT serialization doesn't
# register this as an attribute
def _weight_bias(self):
return self._packed_params._weight_bias()
def weight(self):
return self._weight_bias()[0]
def bias(self):
return self._weight_bias()[1]
def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
self._packed_params.set_weight_bias(w, b)
@classmethod
def from_float(cls, mod):
r"""Create a quantized module from an observed float module
Args:
mod (Module): a float module, either produced by torch.ao.quantization
utilities or provided by the user
"""
if hasattr(mod, 'weight_fake_quant'):
if type_before_parametrizations(mod) == nniqat.LinearBn1d:
mod.weight, mod.bias = fuse_linear_bn_weights(
mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
mod.bn.eps, mod.bn.weight, mod.bn.bias)
weight_post_process = mod.weight_fake_quant
activation_post_process = mod.activation_post_process
else:
# This function does not participate in JIT, so it is OK to ignore
# the type mismatch in assignment. Also, mypy has an issue with
# iterables not being implemented, so we are ignoring those too.
if not isinstance(cls._FLOAT_MODULE, Iterable):
cls._FLOAT_MODULE = [cls._FLOAT_MODULE] # type: ignore[assignment]
supported_modules = ', '.join([float_mod.__name__ for float_mod in cls._FLOAT_MODULE]) # type: ignore[attr-defined]
error_msg = 'nnq.{}.from_float only works for {}, but got: {}'.format(cls.__name__, supported_modules, type(mod))
            assert type_before_parametrizations(mod) in cls._FLOAT_MODULE, error_msg  # type: ignore[attr-defined]
assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
activation_post_process = mod.activation_post_process
if type_before_parametrizations(mod) == nni.LinearReLU:
mod = mod[0]
weight_post_process = mod.qconfig.weight()
weight_post_process(mod.weight)
dtype = weight_post_process.dtype
act_scale, act_zp = activation_post_process.calculate_qparams()
assert dtype == torch.qint8, 'Weight observer must have dtype torch.qint8'
qweight = _quantize_weight(mod.weight.float(), weight_post_process)
qlinear = cls(mod.in_features,
mod.out_features,
dtype=dtype)
qlinear.set_weight_bias(qweight, mod.bias)
qlinear.scale = float(act_scale)
qlinear.zero_point = int(act_zp)
return qlinear
@classmethod
def from_reference(cls, ref_qlinear, output_scale, output_zero_point):
r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
Args:
ref_qlinear (Module): a reference quantized linear module, either produced by torch.ao.quantization
utilities or provided by the user
output_scale (float): scale for output Tensor
zero_point (int): zero point for output Tensor
"""
qlinear = cls(
ref_qlinear.in_features,
ref_qlinear.out_features)
qweight = ref_qlinear.get_quantized_weight()
qlinear.set_weight_bias(qweight, ref_qlinear.bias)
qlinear.scale = float(output_scale)
qlinear.zero_point = int(output_zero_point)
return qlinear