1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
|
"""
This module implements nonuniform observers used to collect statistics about
the values observed during calibration (PTQ) or training (QAT).
"""
import torch
import itertools
import matplotlib.pyplot as plt
from torch.ao.quantization.observer import ObserverBase
from torch.ao.quantization.experimental.apot_utils import float_to_apot, apot_to_float
# TODO: Consider adding NonUniformQuantizationObserverBase class
# when more than one non-uniform method is implemented
class APoTObserver(ObserverBase):
    r"""Observer that derives APoT (Additive Powers-of-Two) nonuniform
    quantization parameters from the running min/max of observed tensors.

    See the APoT paper: https://arxiv.org/pdf/1909.13144.pdf.

    Args:
        b: must be divisible by ``k`` (presumably the total quantization
            bitwidth -- confirm against the APoT paper)
        k: must be nonzero (presumably the base bitwidth of each additive
            term -- confirm against the APoT paper)
        dtype: forwarded unchanged to ``ObserverBase``
    """
    b: int
    k: int
    # number of additive terms, ``b // k``; set by ``_calculate_qparams``
    n: int
    # running minimum / maximum recorded by ``forward``; empty until a
    # non-empty tensor has been observed (or an override has been supplied)
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
            self,
            b,
            k,
            dtype=torch.quint8) -> None:
        super().__init__(dtype)
        self.b = b
        self.k = k
        # Empty tensors mean "nothing observed yet"; forward() tests numel().
        self.min_val = torch.tensor([])
        self.max_val = torch.tensor([])

    def calculate_qparams(self, signed):
        r"""Calculates nonuniform quantization parameters from the min and max
        values recorded so far.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
        Returns:
            the ``(alpha, gamma, quantization_levels, level_indices)`` tuple
            produced by ``_calculate_qparams``
        """
        return self._calculate_qparams(signed, self.min_val, self.max_val)

    def _calculate_qparams(self, signed: bool, min_val=None, max_val=None):
        r"""Calculates nonuniform quantization parameters according to APoT paper:
        https://arxiv.org/pdf/1909.13144.pdf.

        ``min_val`` / ``max_val`` must be populated (through ``forward`` or
        the override arguments) before calling; as a side effect the
        overrides are stored on the observer and ``self.n`` is set.

        Args:
            signed: specifies whether to include signed values in quantization level calculations
            min_val: optional arg that can override min_val internal attribute
            max_val: optional arg that can override max_val internal attribute
        Returns:
            alpha: alpha quantization parameter, max of abs value of observed values
            gamma: gamma quantization parameter, defined to ensure that alpha is the maximum of the range
            quantization_levels: non-uniform quantization levels (fp representation)
            level_indices: int representation of quantization_levels indices
        Raises:
            AssertionError: if ``k`` is zero/unset or ``b`` is not divisible by ``k``
        """
        if min_val is not None:
            self.min_val = min_val
        if max_val is not None:
            self.max_val = max_val

        # compute alpha
        alpha = torch.max(-self.min_val, self.max_val)

        # check for valid inputs of b, k
        # (kept as assertions so callers expecting AssertionError still work)
        assert self.k and self.k != 0, "k must be nonzero"
        assert self.b % self.k == 0, "b must be divisible by k"

        # compute n and store as member variable
        self.n = self.b // self.k

        # one tensor of candidate terms per additive position
        p_all = []
        for i in range(self.n):
            # every position can contribute 0, then 2 ** k - 1 powers of two
            terms = [0]
            for j in range(2 ** self.k - 1):
                curr_ele = 2 ** (-(i + j * self.n))
                terms.append(curr_ele)
                # introduce signed numbers
                if signed:
                    terms.append(-curr_ele)
            p_curr = torch.tensor(terms)
            if signed:
                # sort in reverse order so the largest term sits at index 0
                p_curr, _ = torch.sort(p_curr, descending=True)
            p_all.append(p_curr)

        # gamma is defined to ensure alpha is at the max of the range, so sum
        # the largest term of every tensor: index 0 after the descending sort
        # when signed, index 1 (just after the leading 0) otherwise
        largest_idx = 0 if signed else 1
        p_sum = 0.0
        for tens in p_all:
            p_sum += float(tens[largest_idx])

        # assign gamma
        gamma = alpha / p_sum
        gamma_scalar = float(gamma)

        # each quantization level is the gamma-scaled sum of one row of the
        # cartesian product of the per-position terms
        quantization_levels_list = []
        for row in itertools.product(*p_all):
            level = 0.0
            for ele in row:
                level = level + ele
            quantization_levels_list.append(gamma_scalar * level)

        quantization_levels = torch.tensor(quantization_levels_list)
        quantization_levels, level_indices = quantization_levels.sort()

        return (alpha, gamma, quantization_levels, level_indices)

    def forward(self, x_orig):
        r"""Records the running minimum and maximum of ``x``.

        Args:
            x_orig: Tensor to be observed for min and max val
        Returns:
            ``x_orig`` unchanged; empty tensors are passed through without
            updating the recorded statistics
        """
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        min_val, max_val = torch.aminmax(x)
        # fold in previously recorded extrema, if any
        if self.min_val.numel():
            min_val = torch.min(min_val, self.min_val)
        if self.max_val.numel():
            max_val = torch.max(max_val, self.max_val)
        self.min_val = min_val
        self.max_val = max_val
        return x_orig

    def quant_levels_visualization(self, signed=False):
        r"""Displays visualization of APoT quantization levels.

        Plots the quantize/dequantize round trip of 1000 evenly spaced
        inputs in ``[0, 1)`` using the observer's current qparams.

        Args:
            signed: bool to indicate if qparams should be signed/unsigned
        """
        alpha, gamma, quantization_levels, level_indices = self.calculate_qparams(signed)

        xs = [float(x) / 1000.0 for x in range(1000)]
        ys = [apot_to_float(float_to_apot(x, quantization_levels, level_indices, alpha),
                            quantization_levels, level_indices).item() for x in xs]

        plt.figure(figsize=(15, 10))

        plt.plot(xs, ys)
        plt.title("APoT Quantization Plot")
        plt.xlabel("Full Precision")
        plt.ylabel("Quantized")
        plt.show()
|