1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
|
import torch
from torch.nn import Module
from .observer import MovingAverageMinMaxObserver, HistogramObserver, MovingAveragePerChannelMinMaxObserver, _with_args
import re
class FakeQuantize(Module):
    r"""Simulate the quantize and dequantize operations in training time.

    The output of this module is given by::

        x_out = (clamp(round(x/scale + zero_point), quant_min, quant_max) - zero_point) * scale

    * :attr:`scale` defines the scale factor used for quantization.
    * :attr:`zero_point` specifies the quantized value to which 0 in floating point maps to
    * :attr:`quant_min` specifies the minimum allowable quantized value.
    * :attr:`quant_max` specifies the maximum allowable quantized value.
    * :attr:`fake_quant_enabled` controls the application of fake quantization on tensors, note that
      statistics can still be updated.
    * :attr:`observer_enabled` controls statistics collection on tensors
    * :attr:`dtype` specifies the quantized dtype that is being emulated with fake-quantization,
      allowable values are torch.qint8 and torch.quint8. The values of quant_min and
      quant_max should be chosen to be consistent with the dtype

    Args:
        observer (module): Module for observing statistics on input tensors and calculating scale
            and zero-point. It is called as ``observer(**observer_kwargs)``.
        quant_min (int): The minimum allowable quantized value.
        quant_max (int): The maximum allowable quantized value.
        observer_kwargs (optional): Arguments for the observer module

    Attributes:
        activation_post_process (Module): The observer instance built from ``observer``; it
            collects statistics on the input tensor and provides ``calculate_qparams()``.
        scale (Tensor): Buffer holding the most recently computed scale(s).
        zero_point (Tensor): Buffer holding the most recently computed zero point(s).
        fake_quant_enabled (Tensor): uint8 buffer of one element; 1 enables fake quantization.
        observer_enabled (Tensor): uint8 buffer of one element; 1 enables the observer.
        dtype, qscheme: Copied from the observer instance.
        ch_axis (int): The observer's ``ch_axis`` if it has one, otherwise -1.
    """

    def __init__(self, observer=MovingAverageMinMaxObserver, quant_min=0, quant_max=255, **observer_kwargs):
        super(FakeQuantize, self).__init__()
        assert quant_min <= quant_max, \
            'quant_min must be less than or equal to quant_max'
        self.quant_min = quant_min
        self.quant_max = quant_max
        # fake_quant_enabled and observer_enabled are buffers to support their
        # replication in DDP. Data type is uint8 because NCCL does not support
        # bool tensors.
        self.register_buffer('fake_quant_enabled', torch.tensor([1], dtype=torch.uint8))
        self.register_buffer('observer_enabled', torch.tensor([1], dtype=torch.uint8))
        self.activation_post_process = observer(**observer_kwargs)
        # The requested integer range has to fit in the dtype the observer emulates.
        dtype_info = torch.iinfo(self.activation_post_process.dtype)
        assert dtype_info.min <= quant_min, 'quant_min out of bound'
        assert quant_max <= dtype_info.max, 'quant_max out of bound'
        # Single-element placeholders; forward() resizes them to the shape of
        # the qparams returned by the observer.
        self.register_buffer('scale', torch.tensor([1.0]))
        self.register_buffer('zero_point', torch.tensor([0]))
        self.dtype = self.activation_post_process.dtype
        self.qscheme = self.activation_post_process.qscheme
        # Observers without a channel axis get the sentinel -1.
        self.ch_axis = getattr(self.activation_post_process, 'ch_axis', -1)

    @torch.jit.export
    def enable_fake_quant(self, enabled=True):
        # type: (bool) -> None
        """Set the fake-quantization flag buffer to 1 (``enabled``) or 0."""
        self.fake_quant_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_fake_quant(self):
        """Set the fake-quantization flag buffer to 0."""
        self.enable_fake_quant(False)

    @torch.jit.export
    def enable_observer(self, enabled=True):
        # type: (bool) -> None
        """Set the observer flag buffer to 1 (``enabled``) or 0."""
        self.observer_enabled[0] = 1 if enabled else 0

    @torch.jit.export
    def disable_observer(self):
        """Set the observer flag buffer to 0."""
        self.enable_observer(False)

    @torch.jit.export
    def calculate_qparams(self):
        """Return the observer's ``calculate_qparams()`` result."""
        return self.activation_post_process.calculate_qparams()

    def forward(self, X):
        """Optionally observe ``X``, then optionally fake-quantize it.

        When the observer flag is 1, the observer is run on ``X.detach()`` and
        the ``scale`` / ``zero_point`` buffers are refreshed from its qparams.
        When the fake-quant flag is 1, ``X`` is passed through the per-channel
        or per-tensor fake-quantize op selected by ``qscheme``.

        Args:
            X (Tensor): Input tensor.

        Returns:
            Tensor: The fake-quantized tensor, or ``X`` unchanged when fake
            quantization is disabled.
        """
        if self.observer_enabled[0] == 1:
            self.activation_post_process(X.detach())
            _scale, _zero_point = self.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
            # Update the buffers in place (resize + copy) rather than
            # rebinding the attributes to new tensors.
            self.scale.resize_(_scale.shape)
            self.scale.copy_(_scale)
            self.zero_point.resize_(_zero_point.shape)
            self.zero_point.copy_(_zero_point)

        if self.fake_quant_enabled[0] == 1:
            if self.qscheme == torch.per_channel_symmetric or self.qscheme == torch.per_channel_affine:
                X = torch.fake_quantize_per_channel_affine(X, self.scale, self.zero_point,
                                                           self.ch_axis, self.quant_min, self.quant_max)
            else:
                X = torch.fake_quantize_per_tensor_affine(X, float(self.scale),
                                                          int(self.zero_point), self.quant_min,
                                                          self.quant_max)
        return X

    with_args = classmethod(_with_args)

    @torch.jit.export
    def extra_repr(self):
        """Return a one-line summary of the flags, range, dtype, scheme and qparams."""
        # Adjacent literals instead of a backslash continuation inside the
        # string, so the text does not depend on how the source is indented.
        template = (
            'fake_quant_enabled={}, observer_enabled={}, '
            'quant_min={}, quant_max={}, dtype={}, qscheme={}, ch_axis={}, '
            'scale={}, zero_point={}'
        )
        return template.format(
            self.fake_quant_enabled, self.observer_enabled,
            self.quant_min, self.quant_max,
            self.dtype, self.qscheme, self.ch_axis, self.scale, self.zero_point)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        """Save module state, then store ``scale`` and ``zero_point`` explicitly."""
        # We cannot currently register scalar values as buffers, so need to manually
        # specify serialization here.
        super(FakeQuantize, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'scale'] = self.scale
        destination[prefix + 'zero_point'] = self.zero_point

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Load state, accepting ``scale`` / ``zero_point`` of a different shape.

        The two buffers are created with one element in ``__init__`` and are
        resized by ``forward`` to the shape of the observer's qparams, so a
        checkpoint may hold them with a shape that differs from a freshly
        constructed module. Without this override the parent loader reports a
        size mismatch for them.
        """
        for name in ('scale', 'zero_point'):
            key = prefix + name
            if key in state_dict:
                val = state_dict[key]
                # Resize our own buffer to the incoming shape and let the
                # parent loader copy the values in. Rebinding the attribute to
                # `val` would make this module share storage (and device and
                # dtype) with the checkpoint, so later in-place updates in
                # forward() would modify the caller's state_dict.
                getattr(self, name).resize_(val.shape)
            elif strict:
                missing_keys.append(key)
        super(FakeQuantize, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                                        missing_keys, unexpected_keys, error_msgs)
# Ready-made FakeQuantize factories, each built with FakeQuantize.with_args.

# Per-tensor affine, quint8, range 0..255, reduced range, moving-average min/max observer.
default_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)

# Per-tensor symmetric, qint8, range -128..127, full range, moving-average min/max observer.
default_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_tensor_symmetric,
    reduce_range=False,
)

# Per-channel symmetric along axis 0, qint8, range -128..127, full range.
default_per_channel_weight_fake_quant = FakeQuantize.with_args(
    observer=MovingAveragePerChannelMinMaxObserver,
    quant_min=-128,
    quant_max=127,
    dtype=torch.qint8,
    qscheme=torch.per_channel_symmetric,
    reduce_range=False,
    ch_axis=0,
)

# Per-tensor affine, quint8, range 0..255, reduced range, histogram observer.
default_histogram_fake_quant = FakeQuantize.with_args(
    observer=HistogramObserver,
    quant_min=0,
    quant_max=255,
    dtype=torch.quint8,
    qscheme=torch.per_tensor_affine,
    reduce_range=True,
)
def _is_fake_quant_script_module(mod):
    """Tell whether ``mod`` is a scripted ``FakeQuantize`` module.

    Args:
        mod: Object to inspect.

    Returns:
        bool: True only if ``mod`` is a ``torch.jit.RecursiveScriptModule``
        whose qualified name, with its first dotted component and any
        ``.___torch_mangle_<n>`` segments removed, equals
        ``torch.quantization.fake_quantize.FakeQuantize``.
    """
    if not isinstance(mod, torch.jit.RecursiveScriptModule):
        return False

    # A qualified name looks like
    # '__torch__.torch.quantization.fake_quantize.___torch_mangle_2.FakeQuantize'
    qualified_name = mod._c.qualified_name
    # Keep everything after the first dot, i.e. drop the leading component.
    without_prefix = qualified_name.split('.', 1)[1]
    # Strip the mangling segments so differently mangled copies compare equal.
    demangled = re.sub(r'\.___torch_mangle_\d+', '', without_prefix)
    return demangled == 'torch.quantization.fake_quantize.FakeQuantize'
def disable_fake_quant(mod):
    """Disable fake quantization on ``mod`` if it is a fake-quantize module.

    ``mod`` is acted on when it is a ``FakeQuantize`` instance (subclasses
    included) or a scripted ``FakeQuantize``; any other object is ignored.

    Args:
        mod: Object to inspect.
    """
    # isinstance rather than an exact type comparison, so subclasses (which
    # inherit disable_fake_quant) are not silently skipped.
    if isinstance(mod, FakeQuantize) or _is_fake_quant_script_module(mod):
        mod.disable_fake_quant()
def enable_fake_quant(mod):
    """Enable fake quantization on ``mod`` if it is a fake-quantize module.

    ``mod`` is acted on when it is a ``FakeQuantize`` instance (subclasses
    included) or a scripted ``FakeQuantize``; any other object is ignored.

    Args:
        mod: Object to inspect.
    """
    # isinstance rather than an exact type comparison, so subclasses (which
    # inherit enable_fake_quant) are not silently skipped.
    if isinstance(mod, FakeQuantize) or _is_fake_quant_script_module(mod):
        mod.enable_fake_quant()
def disable_observer(mod):
    """Disable the observer on ``mod`` if it is a fake-quantize module.

    ``mod`` is acted on when it is a ``FakeQuantize`` instance (subclasses
    included) or a scripted ``FakeQuantize``; any other object is ignored.

    Args:
        mod: Object to inspect.
    """
    # isinstance rather than an exact type comparison, so subclasses (which
    # inherit disable_observer) are not silently skipped.
    if isinstance(mod, FakeQuantize) or _is_fake_quant_script_module(mod):
        mod.disable_observer()
def enable_observer(mod):
    """Enable the observer on ``mod`` if it is a fake-quantize module.

    ``mod`` is acted on when it is a ``FakeQuantize`` instance (subclasses
    included) or a scripted ``FakeQuantize``; any other object is ignored.

    Args:
        mod: Object to inspect.
    """
    # isinstance rather than an exact type comparison, so subclasses (which
    # inherit enable_observer) are not silently skipped.
    if isinstance(mod, FakeQuantize) or _is_fake_quant_script_module(mod):
        mod.enable_observer()
|