1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529
|
# -*- coding: utf-8 -*-
# Owner(s): ["oncall: quantization"]
import torch
import torch._C_flatbuffer
from torch.ao.quantization import (
default_dynamic_qconfig,
per_channel_dynamic_qconfig,
)
from torch.ao.quantization.quantize_jit import (
prepare_dynamic_jit,
convert_dynamic_jit,
_prepare_ondevice_dynamic_jit,
_quantize_ondevice_dynamic_jit,
)
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_quantization import (
get_script_module,
LinearAddModel,
)
from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule
from torch.testing import FileCheck
from torch.utils import bundled_inputs as bundled_inputs
import io
from typing import Dict
class myMod(torch.nn.Module):
    """Two stacked 5->5 linear layers whose first weight is supplied by the caller.

    Args:
        weight: parameter installed as ``fc1.weight`` in place of the freshly
            initialised one.
    """

    def __init__(self, weight):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 5).float()
        # Swap in the externally provided weight parameter.
        self.fc1.weight = weight
        self.fc2 = torch.nn.Linear(5, 5).float()

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden)
class MyConvLinearModule(torch.nn.Module):
    """Conv2d followed by three linear ops; a fixture for dynamic-PTQ tests.

    The conv output's last dimension (5 for the example input) is fed through
    ``myMod``'s two 5x5 linears and then a functional linear using ``weight1``.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 5, 3)
        fc1_weight = torch.nn.Parameter(torch.ones(5, 5))
        self.weight1 = torch.nn.Parameter(torch.ones(5, 5))
        self.mymod = myMod(fc1_weight)

    def forward(self, x):
        features = self.mymod(self.conv(x))
        return torch.nn.functional.linear(features, self.weight1)

    def get_example_inputs(self):
        """Return a 1-tuple holding a random ``(1, 3, 12, 7)`` tensor."""
        example = torch.rand(1, 3, 12, 7)
        return (example,)
class OnDevicePTQUtils(object):
observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver']
@staticmethod
def insert_observers(model, qconfig_dict):
inputs = model.get_example_inputs()
scripted_model = get_script_module(model, False, inputs)
scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict)
return scripted_model
@staticmethod
def ptq_dynamic_quantize(model, qconfig_dict):
inputs = model.get_example_inputs()
m = get_script_module(model, False, inputs)
m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True)
return m
@staticmethod
def find_observer_modules(m):
observer_modules = []
for child_module in m.children():
if child_module.original_name in OnDevicePTQUtils.observer_module_name:
observer_modules.append(child_module)
return observer_modules
@staticmethod
def is_value_type_observer(value):
type_name = value.type()
for observer_type in OnDevicePTQUtils.observer_module_name:
if observer_type in type_name.str():
return True
return False
@staticmethod
def is_calculate_qparam(node):
if node.kind() == "prim::CallMethod":
if node.s('name') == "calculate_qparams":
return True
return False
@staticmethod
def get_linear_packed_param_fp_weight(node):
weight = node.inputsAt(0).node()
if weight.kind() != "aten::quantize_per_tensor" and weight.kind() != "aten::quantize_per_channel":
raise ValueError("Quantized weight must be produced.")
fp_weight = weight.inputsAt(0).node()
assert fp_weight.kind() == "prim::GetAttr", "Weight must be an attribute of the module."
fp_weight_name = fp_weight.s('name')
return fp_weight_name
@staticmethod
def is_per_channel_quantized_packed_param(node):
assert node.kind() == 'quantized::linear_prepack', "Node must corresponds to linear_prepack."
weight = node.inputsAt(0).node()
assert weight.kind() != "aten::quantize_per_tensor" or weight.kind() != "aten::quantize_per_channel"
return weight.kind() != "aten::quantize_per_tensor"
class TestOnDeviceDynamicPTQInsertObservers(TestCase):
    """Tests for the observer-insertion stage of on-device dynamic PTQ."""

    def _check_num_and_type_of_observers(self, model, num_observers):
        """Assert ``model`` receives ``num_observers`` observers of the type each qconfig implies."""
        for qconfig, expected_observer in (
            (default_dynamic_qconfig, 'MinMaxObserver'),
            (per_channel_dynamic_qconfig, 'PerChannelMinMaxObserver'),
        ):
            qconfig_dict = {"": qconfig}
            scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
            observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model)
            self.assertEqual(len(observer_modules), num_observers)
            for observer in observer_modules:
                self.assertEqual(observer.original_name, expected_observer)

    def _check_observer_method(self, model, num_observers):
        """Check the graphs produced by observer insertion.

        The observer-inserted ``forward`` graph must have as many lines as the
        inlined original, and ``observe_forward`` / ``reset_observers_forward``
        must each call the observers exactly ``num_observers`` times.
        """
        qconfig_dict = {"": default_dynamic_qconfig}
        inputs = model.get_example_inputs()
        orig_scripted_model = get_script_module(model, False, inputs)
        torch._C._jit_pass_inline(orig_scripted_model.graph)
        orig_forward_graph = orig_scripted_model.graph.str()
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        quant_forward_graph = scripted_model.graph.str()
        # Exact graph matching is difficult, so only the number of lines is compared.
        self.assertEqual(len(orig_forward_graph.splitlines()), len(quant_forward_graph.splitlines()))
        observe_method = scripted_model.observe_forward.graph
        FileCheck().check_count(
            "prim::CallMethod[name=\"forward\"](%_observer", num_observers, exactly=True
        ).run(observe_method)
        reset_observers_method = scripted_model.reset_observers_forward.graph
        FileCheck().check_count(
            "prim::CallMethod[name=\"reset_min_max_vals\"](%_observer", num_observers, exactly=True
        ).run(reset_observers_method)

    def _observer_is_weight_only(self, node):
        """Return True if ``node`` calls an observer's ``forward`` on a ``prim::GetAttr`` value."""
        if node.kind() != "prim::CallMethod" or node.s("name") != "forward":
            return False
        if not OnDevicePTQUtils.is_value_type_observer(node.inputsAt(0)):
            return False
        return node.inputsAt(1).node().kind() == "prim::GetAttr"

    def test_num_observers(self):
        model = LinearAddModel()
        self._check_num_and_type_of_observers(model, 2)
        model = MyConvLinearModule()
        self._check_num_and_type_of_observers(model, 3)

    def test_observe_method(self):
        model = MyConvLinearModule()
        self._check_observer_method(model, 3)

    def test_weight_only_observers(self):
        model = MyConvLinearModule()
        qconfig_dict = {"": default_dynamic_qconfig}
        scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict)
        observe_forward_graph = scripted_model.observe_forward.graph
        num_weight_only_observers = sum(
            1 for node in observe_forward_graph.nodes() if self._observer_is_weight_only(node)
        )
        self.assertEqual(num_weight_only_observers, 3)
class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase):
    """Tests for the ``quantize_forward`` method produced by on-device dynamic PTQ."""

    def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0):
        """Assert ``quantize_forward`` has ``num_nodes`` quantize ops, ``per_channel`` of them per channel."""
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        for n in quantize_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)

    def _validate_calculate_qparams(self, model, num_nodes):
        """Assert ``quantize_forward`` calls ``calculate_qparams`` exactly ``num_nodes`` times."""
        num_calculate_qparams = sum(
            1 for n in model.quantize_forward.graph.nodes()
            if OnDevicePTQUtils.is_calculate_qparam(n)
        )
        self.assertEqual(num_calculate_qparams, num_nodes)

    def _validate_no_observer_forward(self, model):
        """Assert ``quantize_forward`` never calls an observer's ``forward``.

        Returns True once the assertion passes, for callers that inspect the result.
        """
        quantize_forward_graph = model.quantize_forward.graph
        observer_calls = [
            n for n in quantize_forward_graph.nodes()
            if n.kind() == "prim::CallMethod"
            and n.s("name") == "forward"
            and OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0))
        ]
        self.assertFalse(observer_calls, "quantize_forward must not call observer forward")
        return True

    def _check_quant_dequant_and_calc_qparams(self, model, num_nodes):
        """Run the quantize_forward graph checks for per-tensor and per-channel qconfigs."""
        for qconfig, num_per_channel in (
            (default_dynamic_qconfig, 0),
            (per_channel_dynamic_qconfig, num_nodes),
        ):
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, {"": qconfig})
            self._validate_quant_dequant_nodes(m, num_nodes, num_per_channel)
            self._validate_calculate_qparams(m, num_nodes)
            self._validate_no_observer_forward(m)

    def _check_quantize_forward_runs(self, model):
        """Check ``observe_forward`` then ``quantize_forward`` run for both qconfigs."""
        inputs = model.get_example_inputs()
        for qconfig in (default_dynamic_qconfig, per_channel_dynamic_qconfig):
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, {"": qconfig})
            # observe_forward must run first to record the stats that produce
            # correct scales and zero points.
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)

    def test_num_quant_dequant_nodes(self):
        model = LinearAddModel()
        self._check_quant_dequant_and_calc_qparams(model, 2)
        model = MyConvLinearModule()
        self._check_quant_dequant_and_calc_qparams(model, 3)

    def test_quantize_forward_runs(self):
        model = LinearAddModel()
        self._check_quantize_forward_runs(model)
        model = MyConvLinearModule()
        self._check_quantize_forward_runs(model)
class TestOnDeviceDynamicPTQFinalize(TestCase):
    """Tests for the finalized on-device dynamic PTQ model.

    Covers the packed params written by ``quantize_forward``, the
    ``quantized_forward`` graph, agreement with off-device dynamic PTQ,
    TorchScript save/load, and the lite-interpreter / flatbuffer path.
    """

    def _validate_packed_params(self, model, num_nodes, per_channel=0):
        """Assert ``quantize_forward`` stores ``num_nodes`` linear_prepack results via SetAttr.

        Only linear_prepack values that are written with ``prim::SetAttr`` are
        counted; their uses must total ``num_nodes`` and exactly ``per_channel``
        of them must pack a per-channel quantized weight.
        """
        quantize_forward_graph = model.quantize_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        linear_prepack = 0
        linear_prepack_uses = 0
        for n in quantize_forward_graph.nodes():
            if n.kind() != 'prim::SetAttr':
                continue
            maybe_packed_param_value = n.inputsAt(1)
            maybe_packed_param = maybe_packed_param_value.node()
            if maybe_packed_param.kind() != 'quantized::linear_prepack':
                continue
            linear_prepack += 1
            linear_prepack_uses += len(maybe_packed_param_value.uses())
            if OnDevicePTQUtils.is_per_channel_quantized_packed_param(maybe_packed_param):
                quantize_per_channel += 1
            else:
                quantize_per_tensor += 1
        self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes)
        self.assertEqual(quantize_per_channel, per_channel)
        self.assertEqual(linear_prepack, num_nodes)
        self.assertEqual(linear_prepack_uses, num_nodes)

    def _validate_no_linear_unpack(self, model):
        """Assert ``quantize_forward`` contains no ``quantized::linear_unpack`` node.

        Returns True once the assertion passes, for callers that inspect the result.
        """
        unpack_nodes = [
            n for n in model.quantize_forward.graph.nodes()
            if n.kind() == 'quantized::linear_unpack'
        ]
        self.assertFalse(unpack_nodes, "quantize_forward must not contain quantized::linear_unpack")
        return True

    def _validate_setattr_fp_weights(self, model, num_nodes):
        """Assert the fp weight behind each stored linear_prepack is reset to a constant.

        Exactly ``num_nodes`` SetAttr nodes must write a ``prim::Constant`` into
        one of those fp weight attributes.
        """
        quantize_forward_graph = model.quantize_forward.graph
        fp_weight_names = []
        for n in quantize_forward_graph.nodes():
            if n.kind() == 'prim::SetAttr':
                maybe_packed_param = n.inputsAt(1).node()
                if maybe_packed_param.kind() == 'quantized::linear_prepack':
                    fp_weight_names.append(
                        OnDevicePTQUtils.get_linear_packed_param_fp_weight(maybe_packed_param))
        # Detects
        #   %x = prim::Constant
        #   = prim::SetAttr(<weight_name>)(module_value, x)
        # i.e. that the original fp weights are reset.
        fp_weights_setattr = 0
        for n in quantize_forward_graph.nodes():
            if (n.kind() == 'prim::SetAttr'
                    and n.s('name') in fp_weight_names
                    and n.inputsAt(1).node().kind() == 'prim::Constant'):
                fp_weights_setattr += 1
        self.assertEqual(fp_weights_setattr, num_nodes)

    def _validate_quantized_forward(self, model, num_nodes):
        """Assert ``quantized_forward`` only uses prepacked weights.

        It must contain no quantize ops, ``num_nodes`` ``quantized::linear_dynamic``
        calls, and ``num_nodes`` GetAttr reads of ``LinearPackedParamsBase`` values.
        """
        quantized_forward_graph = model.quantized_forward.graph
        quantize_per_tensor = quantize_per_channel = 0
        quantized_linear_dynamic = 0
        linear_packed_params = 0
        for n in quantized_forward_graph.nodes():
            if "aten::quantize_per_tensor" in n.kind():
                quantize_per_tensor += 1
            if "aten::quantize_per_channel" in n.kind():
                quantize_per_channel += 1
            if "quantized::linear_dynamic" in n.kind():
                quantized_linear_dynamic += 1
            if n.kind() == 'prim::GetAttr':
                if "LinearPackedParamsBase" in n.outputsAt(0).type().str():
                    linear_packed_params += 1
        self.assertEqual(quantize_per_tensor, 0)
        self.assertEqual(quantize_per_channel, 0)
        self.assertEqual(quantized_linear_dynamic, num_nodes)
        self.assertEqual(linear_packed_params, num_nodes)
        # NOTE(review): an assertion that quantized_forward contains no prim::SetAttr
        # was left commented out here; consider re-enabling it.

    def _check_quantize_forward(self, model, num_nodes):
        """Run the quantize_forward checks for per-tensor and per-channel qconfigs."""
        for qconfig, num_per_channel in (
            (default_dynamic_qconfig, 0),
            (per_channel_dynamic_qconfig, num_nodes),
        ):
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, {"": qconfig})
            self._validate_packed_params(m, num_nodes, num_per_channel)
            self._validate_no_linear_unpack(m)
            self._validate_setattr_fp_weights(m, num_nodes)

    def _check_quantized_forward(self, model, num_nodes):
        """Run the quantized_forward checks for per-tensor and per-channel qconfigs."""
        for qconfig in (default_dynamic_qconfig, per_channel_dynamic_qconfig):
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, {"": qconfig})
            self._validate_quantized_forward(m, num_nodes)

    @staticmethod
    def _offdevice_dynamic_ptq(model, qconfig_dict):
        """Build the reference model with the off-device JIT dynamic PTQ flow."""
        ref_m = torch.jit.script(model)
        torch._C._jit_pass_inline(ref_m.graph)
        ref_m = prepare_dynamic_jit(ref_m, qconfig_dict)
        return convert_dynamic_jit(ref_m)

    @staticmethod
    def _jit_save_load(module):
        """Round-trip a TorchScript module through ``torch.jit.save``/``load`` in memory."""
        buffer = io.BytesIO()
        torch.jit.save(module, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)

    def _check_against_ref_dynamic_ptq(self, model):
        """Compare on-device dynamic PTQ output with off-device dynamic PTQ output.

        Also asserts that calling the quantized module's plain ``forward`` raises
        after ``quantize_forward`` (presumably because the fp weights have been
        reset to constants).
        """
        model.eval()
        inputs = model.get_example_inputs()
        for qconfig in (default_dynamic_qconfig, per_channel_dynamic_qconfig):
            qconfig_dict = {"": qconfig}
            ref_output = self._offdevice_dynamic_ptq(model, qconfig_dict)(*inputs)
            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
            m.observe_forward(*inputs)
            m.quantize_forward(*inputs)
            output = m.quantized_forward(*inputs)
            self.assertTrue(torch.allclose(ref_output, output))
            with self.assertRaises(Exception):
                m(*inputs)

    def _check_serdes_roundtrip(self, model, qconfig_dict, inputs, ref_output):
        """Save/load the on-device quantized module, run all stages, compare to ``ref_output``."""
        m = self._jit_save_load(OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict))
        m.reset_observers_forward()
        m.observe_forward(*inputs)
        m.quantize_forward(*inputs)
        output = m.quantized_forward(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))

    def _check_lite_interpreter_roundtrip(self, model, qconfig_dict, inputs, ref_output):
        """Quantize through the lite-interpreter device-side API and compare to ``ref_output``.

        Bundled random inputs are attached before saving for the lite interpreter.
        After ``_quantize_ondevice_ptq_dynamic`` the helper methods must be gone,
        and the result is also round-tripped through flatbuffer.
        """
        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
        first_input, = inputs
        rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
        m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
        buffer.seek(0)
        m = _load_for_lite_interpreter(buffer)
        torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
        for method_name in ("quantized_forward", "quantize_forward",
                            "observe_forward", "reset_observers_forward"):
            self.assertFalse(m.find_method(method_name))
        output = m(*inputs)
        self.assertTrue(torch.allclose(ref_output, output))
        # Serialize to flatbuffer, load it back and check the output again.
        extra_files: Dict[str, str] = {}
        fb_bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, extra_files)
        fb_m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(fb_bytes))
        fb_output = fb_m(*inputs)
        self.assertTrue(torch.allclose(ref_output, fb_output))

    def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
        """Check serialization or the device-side API for both per-tensor and per-channel qconfigs.

        The reference output comes from off-device dynamic PTQ after a TorchScript
        save/load round trip.
        """
        model.eval()
        for qconfig in (default_dynamic_qconfig, per_channel_dynamic_qconfig):
            inputs = model.get_example_inputs()
            qconfig_dict = {"": qconfig}
            ref_m = self._jit_save_load(self._offdevice_dynamic_ptq(model, qconfig_dict))
            ref_output = ref_m(*inputs)
            if check_device_side_api:
                self._check_lite_interpreter_roundtrip(model, qconfig_dict, inputs, ref_output)
            else:
                self._check_serdes_roundtrip(model, qconfig_dict, inputs, ref_output)

    def _check_serialization_deserialization(self, model):
        self._check_serdes_and_device_side_api_helper(model, False)

    def _check_device_side_api(self, model):
        self._check_serdes_and_device_side_api_helper(model, True)

    def test_quantize_forward(self):
        model = LinearAddModel()
        self._check_quantize_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantize_forward(model, 3)

    def test_quantized_forward(self):
        model = LinearAddModel()
        self._check_quantized_forward(model, 2)
        model = MyConvLinearModule()
        self._check_quantized_forward(model, 3)

    def test_against_offdevice_dynamic_ptq(self):
        model = LinearAddModel()
        self._check_against_ref_dynamic_ptq(model)
        model = MyConvLinearModule()
        self._check_against_ref_dynamic_ptq(model)

    def test_serialization_deserialization(self):
        model = MyConvLinearModule()
        self._check_serialization_deserialization(model)

    def test_device_side_api(self):
        model = MyConvLinearModule()
        self._check_device_side_api(model)
|