import typing

import torch


class ReferenceQuantizedModule(torch.nn.Module):
    """Common base for reference quantized modules: stores the weight
    quantization parameters as buffers/attributes and provides helpers
    to quantize or fake quantize the weight.
    """

    def _init_weight_qparams(self, weight_qparams, device):
        if weight_qparams is None:
            weight_qparams = {
                "qscheme": torch.per_tensor_affine,
                "dtype": torch.quint8,
                "scale": 1.0,
                "zero_point": 0
            }
        self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
        self.weight_dtype = weight_qparams["dtype"]
        assert self.weight_qscheme in [
            None, torch.per_tensor_affine, torch.per_channel_affine,
            torch.per_channel_affine_float_qparams], \
            f"qscheme: {self.weight_qscheme} is not supported in reference quantized {self._get_name()}"
        if self.weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
            zero_point_dtype = weight_qparams["zero_point"].dtype if \
                isinstance(weight_qparams["zero_point"], torch.Tensor) else \
                torch.int
            w_scale = weight_qparams["scale"]
            w_scale_tensor = w_scale.clone().detach() \
                if isinstance(w_scale, torch.Tensor) \
                else torch.tensor(w_scale, dtype=torch.float, device=device)
            self.register_buffer("weight_scale", w_scale_tensor)
            w_zp = weight_qparams["zero_point"]
            w_zp_tensor = w_zp.clone().detach() \
                if isinstance(w_zp, torch.Tensor) \
                else torch.tensor(w_zp, dtype=zero_point_dtype, device=device)
            self.register_buffer("weight_zero_point", w_zp_tensor)
            if self.weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
                w_axis = weight_qparams["axis"]
                w_axis_tensor = w_axis.clone().detach() \
                    if isinstance(w_axis, torch.Tensor) \
                    else torch.tensor(w_axis, dtype=torch.int, device=device)
                self.register_buffer("weight_axis", w_axis_tensor)
            else:
                # added for TorchScriptability, not used
                self.register_buffer(
                    "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
        else:
            # added for TorchScriptability, and for torch.float
            self.register_buffer("weight_scale", torch.tensor(1.0, dtype=torch.float, device=device))
            self.register_buffer("weight_zero_point", torch.tensor(0, dtype=torch.int, device=device))
            self.register_buffer(
                "weight_axis", torch.tensor(0, dtype=torch.int, device=device))
    def get_weight(self):
        """
        Fake quantize (quantize and then dequantize) the weight with
        the stored weight quantization parameters. This is used to
        simulate the numerics of the quantized weight in a quantized
        model.
        """
        # suppress mypy warning
        assert isinstance(self.weight_scale, torch.Tensor)
        assert isinstance(self.weight_zero_point, torch.Tensor)
        assert isinstance(self.weight_axis, torch.Tensor)
        return _quantize_and_dequantize_weight(
            self.weight,  # type: ignore[arg-type]
            self.weight_qscheme,
            self.weight_dtype,
            self.weight_scale,
            self.weight_zero_point, self.weight_axis)

    def get_quantized_weight(self):
        """
        Quantize the weight with the stored weight quantization
        parameters and return the resulting quantized tensor.
        """
        # suppress mypy warning
        assert isinstance(self.weight_scale, torch.Tensor)
        assert isinstance(self.weight_zero_point, torch.Tensor)
        assert isinstance(self.weight_axis, torch.Tensor)
        return _quantize_weight(
            self.weight,  # type: ignore[arg-type]
            self.weight_qscheme,
            self.weight_dtype,
            self.weight_scale,
            self.weight_zero_point,
            self.weight_axis)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        _save_weight_qparams(
            destination, prefix, self.weight_qscheme, self.weight_dtype,
            self.weight_scale, self.weight_zero_point, self.weight_axis)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # pop the weight qparam entries and set them directly, since values
        # like qscheme and dtype are not tensors and cannot be handled by
        # nn.Module's default loading logic
        for key in _get_weight_qparam_keys(state_dict, prefix):
            setattr(self, key, state_dict[prefix + key])
            state_dict.pop(prefix + key)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False,
            missing_keys, unexpected_keys, error_msgs)
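
# A minimal usage sketch (illustrative; `_DemoReferenceLinear` is a
# hypothetical name, not part of this module): a reference module inherits
# from the floating point module and this class, initializes the weight
# qparams in __init__, and computes with the fake-quantized weight in forward.
#
#   class _DemoReferenceLinear(torch.nn.Linear, ReferenceQuantizedModule):
#       def __init__(self, in_features, out_features, weight_qparams=None, device=None):
#           super().__init__(in_features, out_features, device=device)
#           self._init_weight_qparams(weight_qparams, device)
#
#       def forward(self, x):
#           # linear on the fake-quantized weight simulates int8 numerics
#           return torch.nn.functional.linear(x, self.get_weight(), self.bias)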


def _quantize_weight(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis: torch.Tensor):
    """ Quantize the weight to the given dtype with the given
    quantization parameters (for torch.float16 this is a plain cast)
    """
    if weight_dtype == torch.float16:
        weight = weight.to(weight_dtype)
        return weight

    if weight_qscheme == torch.per_tensor_affine:
        if weight_dtype in [torch.quint8, torch.qint8, torch.qint32]:
            weight = torch.quantize_per_tensor(weight, weight_scale, weight_zero_point, weight_dtype)
            return weight
    elif weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
        if weight_dtype in [torch.quint8, torch.qint8, torch.quint4x2, torch.qint32]:
            weight = torch.quantize_per_channel(
                weight, weight_scale,
                weight_zero_point, weight_axis.item(), weight_dtype)  # type: ignore[arg-type]
            return weight
    raise ValueError(f"Unsupported dtype and qscheme: {weight_dtype}, {weight_qscheme}")
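
# Example call (a sketch; the tensor shapes and values are assumptions): for a
# weight of shape (out_channels, in_channels), quantizing along axis 0 uses one
# (scale, zero_point) pair per output channel.
#
#   w = torch.randn(8, 16)
#   wq = _quantize_weight(
#       w, torch.per_channel_affine, torch.qint8,
#       torch.rand(8) * 0.1 + 0.01,
#       torch.zeros(8, dtype=torch.int64),
#       torch.tensor(0))
#   assert wq.is_quantized and wq.q_per_channel_axis() == 0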


def _quantize_and_dequantize_weight(
        weight: torch.Tensor,
        weight_qscheme: torch.qscheme,
        weight_dtype: torch.dtype,
        weight_scale: torch.Tensor,
        weight_zero_point: torch.Tensor,
        weight_axis: torch.Tensor):
    """ Quantize and then dequantize the weight based on
    the quantization parameters
    """
    if weight_qscheme in [
            torch.per_tensor_affine,
            torch.per_channel_affine,
            torch.per_channel_affine_float_qparams]:
        weight_quant = _quantize_weight(
            weight, weight_qscheme, weight_dtype, weight_scale, weight_zero_point, weight_axis)
        weight_dequant = weight_quant.dequantize()
    else:
        weight_dequant = weight
    return weight_dequant
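
# Sanity sketch of the round trip (illustrative values): per-tensor affine
# fake quantization computes q = clamp(round(w / scale) + zero_point, qmin, qmax)
# and returns (q - zero_point) * scale, so the result is still a float tensor
# but only takes values on the scale-spaced grid.
#
#   w = torch.randn(4, 4)
#   w_fq = _quantize_and_dequantize_weight(
#       w, torch.per_tensor_affine, torch.qint8,
#       torch.tensor(0.02), torch.tensor(0), torch.tensor(0))
#   assert w_fq.dtype == torch.float32
#   assert torch.allclose(w_fq / 0.02, torch.round(w_fq / 0.02), atol=1e-4)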


def _save_weight_qparams(destination, prefix, weight_qscheme, weight_dtype,
                         weight_scale, weight_zero_point, weight_axis):
    destination[prefix + "weight_qscheme"] = weight_qscheme
    destination[prefix + "weight_dtype"] = weight_dtype
    if weight_qscheme is not None:
        destination[prefix + "weight_scale"] = weight_scale
        destination[prefix + "weight_zero_point"] = weight_zero_point
        if weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
            destination[prefix + "weight_axis"] = weight_axis


def _get_weight_qparam_keys(
        state_dict: typing.Dict[str, typing.Any],
        prefix: str):
    keys = ["weight_qscheme", "weight_dtype"]
    weight_qscheme = state_dict[prefix + "weight_qscheme"]
    if weight_qscheme is not None:
        keys.append("weight_scale")
        keys.append("weight_zero_point")
        # note: this must mirror the per-channel condition in _save_weight_qparams
        if weight_qscheme in [torch.per_channel_affine, torch.per_channel_affine_float_qparams]:
            keys.append("weight_axis")
    return keys
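
# End-to-end sketch of the state_dict round trip (builds on the hypothetical
# _DemoReferenceLinear above): _save_to_state_dict writes the weight qparams
# next to the regular parameters, and _load_from_state_dict pops them back
# out before delegating to nn.Module.
#
#   m = _DemoReferenceLinear(16, 8)
#   sd = m.state_dict()   # includes "weight_qscheme", "weight_dtype", "weight_scale", ...
#   m2 = _DemoReferenceLinear(16, 8)
#   m2.load_state_dict(sd)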