import torch
import numpy as np
from torch.nn.quantized.modules.utils import WeightedQuantizedModule
from torch.ao.quantization.experimental.observer import APoTObserver
from torch.ao.quantization.experimental.quantizer import quantize_APoT


class LinearAPoT(WeightedQuantizedModule):
    r"""
    A quantized linear module with quantized tensor as inputs and outputs
    to support APoT quantization.
    We adopt the same interface as `torch.nn.Linear`, see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear for documentation.

    Similar to :class:`~torch.nn.Linear`, attributes will be randomly
    initialized at module creation time and will be overwritten later

    Attributes:
        alpha: `alpha` qparam of output Quantized Tensor, type: Tensor
        gamma: `gamma` qparam of output Quantized Tensor, type: Tensor
        quantization_levels: `quantization_levels` qparam of output Quantized Tensor, type: Tensor
        level_indices: `level_indices` qparam of output Quantized Tensor, type: Tensor
        weight: APoT quantized tensor from weight2quantize
        weight_transposed: transposed weight tensor, used in linear transformation calculation (y = x * A^T + b)
    """

    def __init__(self, weight2quantize: torch.Tensor, b: int, k: int):
        assert weight2quantize.dim() == 2
        assert b % k == 0

        super().__init__()

        # total bitwidth b, base bitwidth k, and number of k-sized blocks n
        self.b = b
        self.k = k
        self.n = self.b // self.k

        # observe the float weight to compute APoT quantization parameters
        observer = APoTObserver(b=self.b, k=self.k)
        observer(weight2quantize)
        self.alpha, self.gamma, self.quantization_levels, self.level_indices = observer.calculate_qparams(signed=False)

        # quantize the weight and cache its transpose for the linear transformation y = x * A^T + b
        quantized_weight = quantize_APoT(weight2quantize, self.alpha, self.gamma, self.quantization_levels, self.level_indices)
        self.weight = quantized_weight.data
        self.weight_transposed = torch.transpose(self.weight, 0, 1)
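
    # Illustrative parameter note (hypothetical values, not defaults from this file):
    # per the APoT formulation, with b=8 and k=2 each weight value is represented by
    # n = b // k = 4 additive power-of-two terms, each with 2-bit precision.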

    def decompose_APoT(self, x):
        r"""
        Decompose binary representation of APoT values into list of k-sized blocks
        Args:
            x (str): binary string representation of an APoT quantized value (output of `bin`)
        """
        # remove "0b" prefix from binary representation
        x = x[2:]

        # initialize list of blocks
        blocks = []

        # split the remaining digits into consecutive k-sized blocks
        while x:
            blocks.append(x[0:self.k])
            x = x[self.k:]

        return blocks
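
    # Illustrative walk-through of decompose_APoT (hypothetical values, assuming k=2):
    # bin(6) == "0b0110"; stripping the "0b" prefix leaves "0110", which is split
    # into the k-sized blocks ["01", "10"].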

    def bitshift_mul(self, weight_val, r):
        r"""
        Compute multiplication of weight_val * r using bitshifting
        method discussed in APoT paper: https://arxiv.org/pdf/1909.13144.pdf
        Args:
            weight_val: list of k-sized binary string blocks representing an APoT quantized weight value
            r: int representing uniformly quantized activation value
        """
        product = 0

        # walk blocks from least significant to most significant
        idx = len(weight_val) - 1
        place = 0

        while idx >= 0:
            block = weight_val[idx]

            # reverse digits in block so the least significant bit comes first
            block = block[::-1]

            curr_block_result = 0
            for ele in block:
                # for each set bit at position `place`, add r shifted by that amount
                if int(ele):
                    curr_block_result += r << place
                place += 1
            idx -= 1

            product += curr_block_result

        return product
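
    # Illustrative walk-through of bitshift_mul (hypothetical values): for
    # weight_val == ["01", "10"] (binary 0110, i.e. 6) and r == 3, the set bits sit
    # at places 1 and 2, so the result is (3 << 1) + (3 << 2) == 18 == 6 * 3.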

    def matmul(self, decomposed_weight, activation):
        r"""
        Perform matrix multiplication between decomposed_weight and
        activation by calling bitshift_mul function for each value
        Args:
            decomposed_weight (np.ndarray): APoT quantized weight decomposed into k-sized binary blocks
            activation (Tensor): uniformly quantized activation
        """
        rows1 = activation.size(dim=0)
        cols1 = activation.size(dim=1)

        rows2 = decomposed_weight.shape[0]
        cols2 = decomposed_weight.shape[1]

        result = torch.zeros(rows1, cols2)

        # compute matrix multiplication with bitshifts
        for i in range(rows1):
            for j in range(cols2):
                for k in range(rows2):
                    weight_val = decomposed_weight[k][j]
                    r = int(activation[i][k])

                    product = self.bitshift_mul(weight_val, r)

                    result[i][j] += product

        return result
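
    # Shape note for matmul (inferred from the loops above): activation is
    # (rows1 x cols1) and decomposed_weight is (rows2 x cols2); the inner loop
    # indexes activation[i][k] for k in range(rows2), so cols1 must equal rows2.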

    def forward(self, activation: torch.Tensor) -> torch.FloatTensor:
        r"""
        Multiply APoT quantized weight and uniformly quantized activation (dtype: quint8)
        with bitshifting instead of matrix multiplication.
        Result has dtype torch.float32
        Args:
            activation (Tensor): uniformly quantized activation tensor
        """
        assert activation.dim() == 2

        weight_rows = self.weight_transposed.size()[0]
        weight_cols = self.weight_transposed.size()[1]

        # decompose each transposed weight value into its binary k-sized blocks
        decomposed_weight = np.empty(shape=(weight_rows, weight_cols), dtype=object)
        for row in range(weight_rows):
            for col in range(weight_cols):
                decomposed_weight[row][col] = self.decompose_APoT(bin(self.weight_transposed[row][col]))

        result = self.matmul(decomposed_weight, activation).type(torch.FloatTensor)

        return result

    @classmethod
    def from_reference(cls,  # type: ignore[override]
                       ref_qlinear,
                       alpha: torch.Tensor,
                       gamma: torch.Tensor,
                       quantization_levels: torch.Tensor,
                       level_indices: torch.Tensor):
        raise NotImplementedError
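

# A minimal usage sketch (illustrative, not part of the original module): the weight
# shape, b, and k values below are assumptions chosen for demonstration, and the call
# assumes quantize_APoT yields integer-valued weight data as forward()'s bin() expects.
if __name__ == "__main__":
    # 4 x 4 float weight, quantized with total bitwidth b=4 and base bitwidth k=2
    float_weight = 1000 * torch.rand(4, 4)
    qlinear = LinearAPoT(float_weight, 4, 2)

    # 2-D activation of small non-negative values standing in for uniformly quantized data
    activation = torch.randint(0, 16, (2, 4)).float()

    # forward() multiplies via bitshifts and returns a float32 tensor of shape (2, 4)
    out = qlinear(activation)
    print(out.shape, out.dtype)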