1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
|
# Owner(s): ["oncall: quantization"]
import copy
from typing import Any, Dict, Optional, Tuple
import torch
from torch._higher_order_ops.out_dtype import out_dtype # noqa: F401
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer import Quantizer
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from torch.export import export_for_training
from torch.testing._internal.common_quantization import (
NodeSpec as ns,
QuantizationTestCase,
skipIfNoQNNPACK,
TestHelperModules,
)
@skipIfNoQNNPACK
class TestPT2ERepresentation(QuantizationTestCase):
def _test_representation(
self,
model: torch.nn.Module,
example_inputs: Tuple[Any, ...],
quantizer: Quantizer,
ref_node_occurrence: Dict[ns, int],
non_ref_node_occurrence: Dict[ns, int],
fixed_output_tol: Optional[float] = None,
output_scale_idx: int = 2,
) -> torch.nn.Module:
# resetting dynamo cache
torch._dynamo.reset()
model = export_for_training(
model,
example_inputs,
).module()
model_copy = copy.deepcopy(model)
model = prepare_pt2e(model, quantizer)
# Calibrate
model(*example_inputs)
model = convert_pt2e(model, use_reference_representation=True)
self.checkGraphModuleNodes(model, expected_node_occurrence=ref_node_occurrence)
# make sure it runs
pt2e_quant_output = model(*example_inputs)
# TODO: torchdynamo times out when we do this, we can enable numerical checking
# after that is fixed
model_copy = prepare_pt2e(model_copy, quantizer)
# Calibrate
model_copy(*example_inputs)
model_copy = convert_pt2e(model_copy, use_reference_representation=False)
self.checkGraphModuleNodes(
model_copy, expected_node_occurrence=non_ref_node_occurrence
)
pt2e_quant_output_copy = model_copy(*example_inputs)
output_tol = None
if fixed_output_tol is not None:
output_tol = fixed_output_tol
else:
idx = 0
for n in model_copy.graph.nodes:
if (
n.target
== torch.ops.quantized_decomposed.quantize_per_tensor.default
):
idx += 1
if idx == output_scale_idx:
output_tol = n.args[1]
assert output_tol is not None
# make sure the result is off by one at most in the quantized integer representation
self.assertTrue(
torch.max(torch.abs(pt2e_quant_output_copy - pt2e_quant_output))
<= (2 * output_tol + 1e-5)
)
def test_static_linear(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
return self.linear(x)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 5),)
self._test_representation(
M().eval(),
example_inputs,
quantizer,
ref_node_occurrence={},
non_ref_node_occurrence={},
)
def test_dynamic_linear(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
return self.linear(x)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=False, is_dynamic=True
)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(2, 5),)
self._test_representation(
M().eval(),
example_inputs,
quantizer,
ref_node_occurrence={},
non_ref_node_occurrence={},
fixed_output_tol=1e-4,
)
def test_conv2d(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv2d = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
return self.conv2d(x)
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=False)
quantizer.set_global(operator_config)
example_inputs = (torch.randn(1, 3, 3, 3),)
self._test_representation(
M().eval(),
example_inputs,
quantizer,
ref_node_occurrence={},
non_ref_node_occurrence={},
)
def test_add(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
return x + y
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(quantization_config)
m_eager = M().eval()
example_inputs = (
torch.randn(1, 3, 3, 3),
torch.randn(1, 3, 3, 3),
)
self._test_representation(
M().eval(),
example_inputs,
quantizer,
ref_node_occurrence={},
non_ref_node_occurrence={},
)
def test_add_relu(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
out = x + y
out = torch.nn.functional.relu(out)
return out
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
example_inputs = (
torch.randn(1, 3, 3, 3),
torch.randn(1, 3, 3, 3),
)
ref_node_occurrence = {
ns.call_function(out_dtype): 2,
}
self._test_representation(
M().eval(),
example_inputs,
quantizer,
ref_node_occurrence=ref_node_occurrence,
non_ref_node_occurrence={},
)
def test_maxpool2d(self):
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = TestHelperModules.ConvMaxPool2d().eval()
example_inputs = (torch.randn(1, 2, 2, 2),)
self._test_representation(
m_eager,
example_inputs,
quantizer,
ref_node_occurrence={},
non_ref_node_occurrence={},
)
def test_qdq_per_channel(self):
"""Test representation for quantize_per_channel and dequantize_per_channel op"""
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(5, 5)
def forward(self, x):
return self.linear(x)
quantizer = XNNPACKQuantizer()
# use per channel quantization for weight
operator_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(operator_config)
m_eager = M().eval()
inputs = [
(torch.randn(1, 5),),
(torch.randn(1, 3, 5),),
(torch.randn(1, 3, 3, 5),),
(torch.randn(1, 3, 3, 3, 5),),
]
for example_inputs in inputs:
ref_node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
): 0,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 0,
}
non_ref_node_occurrence = {
# quantize_per_channel is folded
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_channel.default
): 0,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_channel.default
): 1,
}
self._test_representation(
M().eval(),
example_inputs,
quantizer,
ref_node_occurrence,
non_ref_node_occurrence,
output_scale_idx=2,
)
def test_qdq(self):
"""Test representation for quantize and dequantize op"""
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
return x + y
quantizer = XNNPACKQuantizer()
quantization_config = get_symmetric_quantization_config(is_per_channel=True)
quantizer.set_global(quantization_config)
m_eager = M().eval()
example_inputs = (
torch.randn(1, 3, 3, 3),
torch.randn(1, 3, 3, 3),
)
ref_node_occurrence = {
ns.call_function(torch.ops.quantized_decomposed.quantize_per_tensor): 0,
ns.call_function(torch.ops.quantized_decomposed.dequantize_per_tensor): 0,
}
non_ref_node_occurrence = {
ns.call_function(
torch.ops.quantized_decomposed.quantize_per_tensor.default
): 3,
ns.call_function(
torch.ops.quantized_decomposed.dequantize_per_tensor.default
): 3,
}
self._test_representation(
M().eval(),
example_inputs,
quantizer,
ref_node_occurrence,
non_ref_node_occurrence,
)
|