1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
|
# Owner(s): ["oncall: quantization"]
import torch
from torch.testing._internal.common_utils import TestCase
from torch.ao.quantization.utils import get_fqn_to_example_inputs
from torch.nn.quantized.modules.utils import _quantize_weight
from torch.ao.quantization import MovingAverageMinMaxObserver, MovingAveragePerChannelMinMaxObserver
class TestUtils(TestCase):
def _test_get_fqn_to_example_inputs(self, M, example_inputs, expected_fqn_to_dim):
m = M().eval()
fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
for fqn, expected_dims in expected_fqn_to_dim.items():
assert fqn in expected_fqn_to_dim
example_inputs = fqn_to_example_inputs[fqn]
for example_input, expected_dim in zip(example_inputs, expected_dims):
assert example_input.dim() == expected_dim
def test_get_fqn_to_example_inputs_simple(self):
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
self.sub = Sub()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.sub(x)
return x
expected_fqn_to_dim = {
"": (2,),
"linear1": (2,),
"linear2": (2,),
"sub": (2,),
"sub.linear1": (2,),
"sub.linear2": (2,)
}
example_inputs = (torch.rand(1, 5),)
self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)
def test_get_fqn_to_example_inputs_default_kwargs(self):
""" Test that we can get example inputs for functions with default keyword arguments
"""
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
def forward(self, x, key1=torch.rand(1), key2=torch.rand(1)):
x = self.linear1(x)
x = self.linear2(x)
return x
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
self.sub = Sub()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
# only override `key2`, `key1` will use default
x = self.sub(x, key2=torch.rand(1, 2))
return x
expected_fqn_to_dim = {
"": (2,),
"linear1": (2,),
"linear2": (2,),
# second arg is `key1`, which is using default argument
# third arg is `key2`, override by callsite
"sub": (2, 1, 2),
"sub.linear1": (2,),
"sub.linear2": (2,)
}
example_inputs = (torch.rand(1, 5),)
self._test_get_fqn_to_example_inputs(M, example_inputs, expected_fqn_to_dim)
def test_get_fqn_to_example_inputs_complex_args(self):
""" Test that we can record complex example inputs such as lists and dicts
"""
class Sub(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
def forward(self, x, list_arg, dict_arg):
x = self.linear1(x)
x = self.linear2(x)
return x
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(5, 5)
self.linear2 = torch.nn.Linear(5, 5)
self.sub = Sub()
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.sub(x, [x], {"3": x})
return x
example_inputs = (torch.rand(1, 5),)
m = M().eval()
fqn_to_example_inputs = get_fqn_to_example_inputs(m, example_inputs)
assert "sub" in fqn_to_example_inputs
assert isinstance(fqn_to_example_inputs["sub"][1], list)
assert isinstance(fqn_to_example_inputs["sub"][2], dict) and \
"3" in fqn_to_example_inputs["sub"][2]
def test_quantize_weight_clamping_per_tensor(self):
""" Test quant_{min, max} from per tensor observer is honored by `_quantize_weight` method
"""
fp_min, fp_max = -1000.0, 1000.0
q8_min, q8_max = -10, 10
float_tensor = torch.tensor([fp_min, fp_max])
observer = MovingAverageMinMaxObserver(
averaging_constant=1.0,
dtype=torch.qint8,
quant_min=q8_min,
quant_max=q8_max,
qscheme=torch.per_tensor_symmetric,
)
observer(float_tensor)
assert observer.min_val == fp_min
assert observer.max_val == fp_max
quantized_tensor = _quantize_weight(float_tensor, observer)
assert quantized_tensor.int_repr().max().item() == q8_max
assert quantized_tensor.int_repr().min().item() == q8_min
# Actual weight values can be outside than observer [min_val, max_val] for the moving average observer
float_tensor *= 1.2
quantized_tensor = _quantize_weight(float_tensor, observer)
assert quantized_tensor.int_repr().max().item() == q8_max
assert quantized_tensor.int_repr().min().item() == q8_min
def test_quantize_weight_clamping_per_channel(self):
""" Test quant_{min, max} from per channel observer is honored by `_quantize_weight` method
"""
fp_min, fp_max = -1000.0, 1000.0
q8_min, q8_max = -10, 10
float_tensor = torch.tensor([[fp_min, fp_max]])
observer = MovingAveragePerChannelMinMaxObserver(
averaging_constant=1.0,
dtype=torch.qint8,
quant_min=q8_min,
quant_max=q8_max,
qscheme=torch.per_channel_symmetric,
ch_axis=0,
)
observer(float_tensor)
assert observer.min_val == fp_min
assert observer.max_val == fp_max
quantized_tensor = _quantize_weight(float_tensor, observer)
assert quantized_tensor.int_repr().max().item() == q8_max
assert quantized_tensor.int_repr().min().item() == q8_min
# Actual weight values can be outside than observer [min_val, max_val] for the moving average observer
float_tensor *= 1.2
quantized_tensor = _quantize_weight(float_tensor, observer)
assert quantized_tensor.int_repr().max().item() == q8_max
assert quantized_tensor.int_repr().min().item() == q8_min
|