# mypy: allow-untyped-defs
import operator
import warnings
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import torch.ao.nn.intrinsic as nni
import torch.nn as nn
import torch.nn.functional as F
from torch.ao.quantization.fx.graph_module import _get_observed_graph_module_attr
from torch.ao.quantization.observer import (
_with_args,
ObserverBase,
PerChannelMinMaxObserver,
)
from torch.ao.quantization.utils import _parent_name, check_min_max_valid
from torch.fx import GraphModule
from torch.fx.graph import Node
from .utils import (
get_new_attr_name_with_prefix,
maybe_get_next_module,
node_arg_is_weight,
)
CUSTOM_MODULE_SUPP_LIST: List[Any] = []
def reshape_scale(scale: torch.Tensor, axis: int, input: torch.Tensor) -> torch.Tensor:
"""Reshapes the scale so that we can multiply it to the input by the given axis."""
new_shape = [1] * input.ndim
new_shape[axis] = input.size(axis)
return scale.view(new_shape)
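# Illustrative sketch of the reshape above (values are made up): for a conv
# input of shape (N, C, H, W) and axis=1, a scale of shape (C,) is viewed as
# (1, C, 1, 1) so that it broadcasts against the input:
#
#     x = torch.randn(2, 3, 8, 8)
#     scale = torch.tensor([1.0, 2.0, 0.5])
#     reshape_scale(scale, 1, x).shape   # torch.Size([1, 3, 1, 1])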
qsheme_mapping_per_tensor_to_per_channel = {
torch.per_tensor_affine: torch.per_channel_affine,
torch.per_tensor_symmetric: torch.per_channel_symmetric,
}
class _InputEqualizationObserver(nn.Module):
r"""Observer for tracking the running min/max values of input columns, and
computing the quantization parameters for the overall min/max input values.
Args:
dtype: Quantized data type
qscheme: Quantization scheme
quant_min: Minimum quantization value. If unspecified, it will
follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will
follow the 8-bit setup.
The running minimum/maximum :math:`x_\text{min/max}` are computed in the
same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`,
with the difference that the running min/max values are stored per column.
This observer is intended to be used along with a WeightEqualizationObserver
to calculate the equalization scale.
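The example below is an illustrative sketch (not part of any public API); it
shows how the observer records per-column input ranges:
Example::
    obs = _InputEqualizationObserver()
    obs(torch.randn(4, 8))                       # an (N, C_in) batch
    min_cols, max_cols = obs.get_input_minmax()  # each has shape (8,)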
"""
def __init__(
self,
dtype=torch.quint8,
qscheme=torch.per_tensor_affine,
quant_min=None,
quant_max=None,
factory_kwargs=None,
) -> None:
super().__init__()
if qscheme not in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
raise TypeError("Input qscheme must be per-tensor")
self.dtype = dtype
self.qscheme = qscheme
per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
self.input_obs = PerChannelMinMaxObserver(
ch_axis=1,
dtype=dtype,
qscheme=per_channel_qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs,
)
self.equalization_scale = torch.tensor(1)
self.equalization_shape: List[int] = []
def forward(self, x_orig):
if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
raise ValueError(
"InputEqualizationObserver only supports Linear and Conv layers"
)
# Calculate the shape needed to reshape the equalization scale later (needed for Conv layers)
self.equalization_shape = [1] * x_orig.ndim
self.equalization_shape[1] = x_orig.size(1)
return self.input_obs(x_orig)
def get_input_minmax(self):
return (self.input_obs.min_val, self.input_obs.max_val)
def set_equalization_scale(self, equalization_scale):
# Reshape the equalization scale along axis=1 so that it can be
# multiplied with the input along axis=1
if equalization_scale.nelement() == 1 and equalization_scale == torch.tensor(1):
return
self.equalization_scale = torch.reshape(
equalization_scale, self.equalization_shape
)
def calculate_scaled_minmax(self):
r"""Returns the scaled min/max inputs"""
if (
self.equalization_scale.nelement() == 1
and self.equalization_scale == torch.tensor(1)
):
warnings.warn(
"Must call calculate_equalization_scale before calling calculate_scaled_minmax. "
+ "Will not scale the next quantization observer."
)
return None, None
# Calculate qparams for the scaled min/max inputs
# Scale the input by the equalization scale located at the same column
# index
(min_inputs, max_inputs) = self.get_input_minmax()
equalization_scale_reshaped = reshape_scale(
self.equalization_scale, 0, min_inputs
)
min_input_scaled = torch.min(torch.mul(min_inputs, equalization_scale_reshaped))
max_input_scaled = torch.max(torch.mul(max_inputs, equalization_scale_reshaped))
return min_input_scaled, max_input_scaled
with_args = classmethod(_with_args)
class _WeightEqualizationObserver(nn.Module):
r"""Observer for tracking the running min/max values of weight columns and
rows, and computing the quantization parameters for the weight rows.
Args:
dtype: Quantized data type
qscheme: Quantization scheme
quant_min: Minimum quantization value. If unspecified, it will
follow the 8-bit setup.
quant_max: Maximum quantization value. If unspecified, it will
follow the 8-bit setup.
This observer contains a single PerChannelMinMaxObserver, `weight_col_obs`, used
to record the running minimum and maximum of the columns of incoming weight
tensors. This observer is intended to be used along with an
InputEqualizationObserver to calculate the equalization scale.
The running minimum/maximum :math:`w_\text{min/max}` are computed in the
same way as :class:`~torch.ao.quantization.observer.PerChannelMinMaxObserver`.
"""
def __init__(
self,
dtype=torch.qint8,
qscheme=torch.per_tensor_affine,
quant_min=None,
quant_max=None,
factory_kwargs=None,
) -> None:
super().__init__()
self.dtype = dtype
self.qscheme = qscheme
self.ch_axis = 1
per_channel_qscheme = qscheme
if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}:
per_channel_qscheme = qsheme_mapping_per_tensor_to_per_channel[qscheme]
self.weight_col_obs = PerChannelMinMaxObserver(
ch_axis=1,
dtype=dtype,
qscheme=per_channel_qscheme,
quant_min=quant_min,
quant_max=quant_max,
factory_kwargs=factory_kwargs,
)
self.equalization_scale = torch.tensor(1)
def forward(self, w_orig):
if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
raise ValueError(
"InputEqualizationObserver only supports Linear and Conv layers"
)
return self.weight_col_obs(w_orig)
def get_weight_col_minmax(self):
return (self.weight_col_obs.min_val, self.weight_col_obs.max_val)
def set_equalization_scale(self, equalization_scale):
self.equalization_scale = equalization_scale
with_args = classmethod(_with_args)
def calculate_equalization_scale(
input_obs: _InputEqualizationObserver, weight_obs: _WeightEqualizationObserver
) -> torch.Tensor:
r"""Calculates the equalization scale and sets the equalization_scale value
in the observers.
Args:
input_obs: Observer that tracks the ranges for the input columns
weight_obs: Observer that tracks the ranges for the weight columns
"""
(min_inputs, max_inputs) = input_obs.get_input_minmax()
(min_weights, max_weights) = weight_obs.get_weight_col_minmax()
if not (
check_min_max_valid(min_inputs, max_inputs)
and check_min_max_valid(min_weights, max_weights)
):
warnings.warn(
"Must run observer before calling calculate_equalization_scale. "
+ "Returning default equalization scale torch.tensor(1)."
)
return torch.tensor(1)
if min_inputs.shape != min_weights.shape:
raise ValueError(
"Input and Weight must have the same column dimension. "
+ f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
)
equalization_scale = torch.sqrt(
(max_weights - min_weights) / (max_inputs - min_inputs)
)
# Replace all 'inf', 'nan', 0's with 1s to prevent errors
equalization_scale[equalization_scale == 0.0] = 1
equalization_scale = torch.nan_to_num(equalization_scale, nan=1, posinf=1, neginf=1)
return equalization_scale
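# Illustrative sketch of the intended calibration flow (the observer and
# tensor names below are local to this example, not an established API):
#
#     inp_obs = _InputEqualizationObserver()
#     w_obs = _WeightEqualizationObserver()
#     inp_obs(torch.randn(16, 8))    # activations with 8 input columns
#     w_obs(torch.randn(4, 8))       # weight of shape (out_features, in_features)
#     scale = calculate_equalization_scale(inp_obs, w_obs)   # shape (8,)
#     inp_obs.set_equalization_scale(scale)
#     w_obs.set_equalization_scale(scale)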
class EqualizationQConfig(
namedtuple("EqualizationQConfig", ["input_activation", "weight"])
):
"""
Describes how to quantize a layer or a part of the network specifically for
input-weight equalization by providing settings (observer classes) for
inputs, outputs, and weights.
Note that EqualizationQConfig needs to contain observer **classes** (like
MinMaxObserver) or a callable that returns instances on invocation, not the
concrete observer instances themselves.
The quantization preparation function will instantiate observers multiple times
for each of the layers.
Observer classes usually have reasonable default arguments, but they can be
overridden with the `with_args` method (which behaves like functools.partial):
my_qconfig = EqualizationQConfig(input_activation=_InputEqualizationObserver.with_args(dtype=torch.qint8),
weight=_WeightEqualizationObserver.with_args(dtype=torch.qint8))
"""
def __new__(cls, input_activation=torch.nn.Identity, weight=torch.nn.Identity):
if isinstance(input_activation, nn.Module) or isinstance(weight, nn.Module):
raise ValueError(
"EqualizationQConfig received observer instance, please pass observer class instead. "
+ "Use MyObserver.with_args(x=1) to override arguments to constructor if needed"
)
self = super().__new__(cls, input_activation, weight)
return self
input_equalization_observer = _InputEqualizationObserver.with_args(
dtype=torch.quint8, qscheme=torch.per_tensor_symmetric
)
weight_equalization_observer = _WeightEqualizationObserver.with_args(
dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)
default_equalization_qconfig = EqualizationQConfig(
input_activation=input_equalization_observer, weight=weight_equalization_observer
)
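# Minimal sketch of how the default equalization qconfig can be applied during
# FX prepare (``float_model`` and ``example_inputs`` are placeholders); the
# equalization config is passed through the ``_equalization_config`` argument:
#
#     from torch.ao.quantization import get_default_qconfig_mapping
#     from torch.ao.quantization.quantize_fx import prepare_fx
#
#     prepared = prepare_fx(
#         float_model,
#         get_default_qconfig_mapping(),
#         example_inputs,
#         _equalization_config={"": default_equalization_qconfig},
#     )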
def fused_module_supports_equalization(module) -> bool:
"""Checks if the fused node supports equalization."""
return type(module) in [
nni.LinearReLU,
nni.ConvReLU1d,
nni.ConvReLU2d,
nni.ConvReLU3d,
]
def nn_module_supports_equalization(module) -> bool:
"""Checks if the torch.nn node supports equalization."""
return type(module) in [nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d]
def custom_module_supports_equalization(module) -> bool:
"""Checks if the custom node supports equalization."""
return type(module) in CUSTOM_MODULE_SUPP_LIST
def node_supports_equalization(node: Node, modules) -> bool:
"""Checks if the current node supports equalization
Currently we only support nn.Linear/F.Linear and nn.Conv/F.conv layers
"""
if node.op == "call_module":
return (
nn_module_supports_equalization(modules[str(node.target)])
or fused_module_supports_equalization(modules[str(node.target)])
or custom_module_supports_equalization(modules[str(node.target)])
)
elif node.op == "call_function":
return node.target in [F.linear, F.conv1d, F.conv2d, F.conv3d]
return False
def is_equalization_observer(observer: nn.Module) -> bool:
return isinstance(
observer, (_InputEqualizationObserver, _WeightEqualizationObserver)
)
###############################################################################
# Functions for equalization during convert #
###############################################################################
def get_op_node_and_weight_eq_obs(
input_eq_obs_node: Node, model: GraphModule, modules: Dict[str, nn.Module]
) -> Tuple[Optional[Node], Optional[_WeightEqualizationObserver]]:
"""Gets the following weight equalization observer. There should always
exist a weight equalization observer after an input equalization observer.
Returns the operation node that follows the input equalization observer node
and the weight equalization observer
"""
# Find the op node that comes directly after the input equalization observer
op_node = None
for user in input_eq_obs_node.users.keys():
if node_supports_equalization(user, modules):
op_node = user
break
assert op_node is not None
if op_node.op == "call_module":
# If the op_node is a supported module (e.g., nn.Linear), then it must have a
# WeightEqualizationObserver configuration
maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(
model, "equalization_node_name_to_qconfig"
)
assert maybe_equalization_node_name_to_config is not None
equalization_node_name_to_qconfig: Dict[str, Any] = maybe_equalization_node_name_to_config # type: ignore[assignment]
assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
weight_eq_obs = equalization_node_name_to_qconfig.get(
op_node.name, None
).weight()
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
return op_node, weight_eq_obs
elif op_node.op == "call_function":
weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
if weight_node is not None:
weight_eq_obs = modules[str(weight_node.target)]
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
return op_node, weight_eq_obs
return None, None
def maybe_get_weight_eq_obs_node(
op_node: Node, modules: Dict[str, nn.Module]
) -> Optional[Node]:
"""Gets the weight equalization observer node if it exists."""
assert op_node.op == "call_function"
for node_arg in op_node.args:
if node_arg_is_weight(op_node, node_arg):
assert (
isinstance(node_arg, Node)
and node_arg.op == "call_module"
and isinstance(
modules[str(node_arg.target)], _WeightEqualizationObserver
)
)
return node_arg
return None
def maybe_get_next_input_eq_obs(
node: Node, modules: Dict[str, nn.Module]
) -> Optional[_InputEqualizationObserver]:
"""Gets the following input equalization observer if it exists.
For example, in the case of connecting linear layers:
x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
If the node being passed in is the linear1 node, then we want to return eq_obs2,
the following equalization observer for linear2.
However, if there are no connecting layers:
x -> inp_obs1 -> eq_obs1 -> linear1 -> out_obs1 -> add
Then we want to return None.
In the case of an unfused linear-relu layer with a connecting linear layer:
linear1 -> relu -> out_obs1 -> eq_obs2 -> linear2 -> out_obs2
Since it is unfused, we want to skip over the relu layer and return eq_obs2,
the following equalization observer for linear2.
"""
assert node_supports_equalization(node, modules)
# Locate the following nn.ReLU or F.relu node if it exists
maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
if maybe_relu_node is None:
maybe_relu_node = maybe_get_next_module(
node, modules, target_functional_type=F.relu
)
# Locate the following output observer if it exists.
# We will skip the relu node if it exists.
maybe_obs_node = (
maybe_get_next_module(node, modules, ObserverBase)
if maybe_relu_node is None
else maybe_get_next_module(maybe_relu_node, modules, ObserverBase)
)
if maybe_obs_node is None:
return None
maybe_eq_obs_node = maybe_get_next_module(
maybe_obs_node, modules, _InputEqualizationObserver
)
if maybe_eq_obs_node is None:
return None
maybe_eq_obs = modules[str(maybe_eq_obs_node.target)]
assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
return maybe_eq_obs
def maybe_get_next_equalization_scale(
node: Node, modules: Dict[str, nn.Module]
) -> Optional[torch.Tensor]:
"""If the next next node is an InputEqualizationObserver then we want to
return its equalization scale, else we return 1
This is used in the case where there are two connecting linear layers:
linear1 -> LinearOutObs -> InputEqObs -> linear2
In this case, the node given is linear1 and we want to locate the InputEqObs.
"""
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
if next_inp_eq_obs:
if (
next_inp_eq_obs.equalization_scale.nelement() == 1
and next_inp_eq_obs.equalization_scale == torch.tensor(1)
):
return None
return next_inp_eq_obs.equalization_scale
return None
def scale_input_observer(node: Node, modules: Dict[str, nn.Module]) -> None:
"""Scales the following input quantization observer's min/max values by
updating the values with the scaled min/max values calculated by the input
equalization observer
"""
input_eq_obs = modules[str(node.target)]
assert isinstance(input_eq_obs, _InputEqualizationObserver)
input_quant_obs_node = node.args[0]
assert isinstance(input_quant_obs_node, Node)
input_quant_obs = modules[str(input_quant_obs_node.target)]
if not isinstance(input_quant_obs, ObserverBase):
return
min_input_scaled, max_input_scaled = input_eq_obs.calculate_scaled_minmax()
if min_input_scaled is None and max_input_scaled is None:
return
input_quant_obs.min_val = min_input_scaled
input_quant_obs.max_val = max_input_scaled
def scale_weight_node(
node: Node,
modules: Dict[str, nn.Module],
equalization_scale: torch.Tensor,
next_equalization_scale: Optional[torch.Tensor],
) -> None:
"""Scale the weights for input-weight equalization by multiplying the
weight by 1/equalization_scale and next_equalization_scale
Args:
node: Current node whose weights we want to scale
equalization_scale: Current node's calculated equalization scale
next_equalization_scale: Next node's calculated equalization scale if
the following node needs to be equalized, None otherwise
"""
if equalization_scale is None:
return
if fused_module_supports_equalization(modules[str(node.target)]):
op_module = modules[str(node.target)][0] # type: ignore[index]
else:
op_module = modules[str(node.target)]
assert nn_module_supports_equalization(
op_module
) or custom_module_supports_equalization(op_module)
# Scale the weights for input-weight equalization.
# If the following layer also needs to be equalized, additionally multiply the
# weight by the next equalization scale.
weight = op_module.weight
assert isinstance(weight, torch.Tensor)
# Scale the weights by the reciprocal of the equalization scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=1
equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))
if next_equalization_scale is None:
op_module.weight = nn.Parameter(scaled_weight)
return
# Multiply the weights row wise by the next equalization scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=0
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, weight)
scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
op_module.weight = nn.Parameter(scaled_weight)
# Multiply the bias element wise by the next equalization scale
bias = op_module.bias
if bias is None:
return
assert isinstance(bias, torch.Tensor)
# Reshape the equalization scale so that we can multiply it element-wise to the bias
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
op_module.bias = nn.Parameter(scaled_bias)
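# Worked example for the scaling above (illustrative numbers): with a Linear
# weight W of shape (2, 3), equalization_scale = [2., 4., 8.] and
# next_equalization_scale = [3., 5.], the result is
# W' = diag([3., 5.]) @ W @ diag([0.5, 0.25, 0.125]) and bias' = bias * [3., 5.].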
def scale_weight_functional(
op_node: Node,
model: GraphModule,
modules: Dict[str, nn.Module],
equalization_scale: torch.Tensor,
next_equalization_scale: Optional[torch.Tensor],
) -> None:
"""Scales the weight value for functional layers"""
if equalization_scale is None:
return
# From the given op_node, the path looks like:
# get_attr(weight) -> weight_quant_obs -> weight_eq_obs -> op_node
# So we want to trace back from the op_node to get the equalization observer
# node, then the quantization observer node, and then finally the weight
# node which contains the weight values.
# Get the equalization observer node
weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
if weight_eq_obs_node is None:
return
# Get the quantization observer node
weight_quant_obs_node = weight_eq_obs_node.args[0]
if weight_quant_obs_node is None:
return
assert isinstance(weight_quant_obs_node, Node) and isinstance(
modules[str(weight_quant_obs_node.target)], ObserverBase
)
# Get the get_attr(weight) node
weight_node = weight_quant_obs_node.args[0]
if weight_node is None:
return
assert isinstance(weight_node, Node) and weight_node.op == "get_attr"
weight_parent_name, weight_name = _parent_name(weight_node.target)
weight = getattr(modules[weight_parent_name], weight_name)
# Scale the weights for input-weight equalization.
# If the following layer also needs to be equalized, additionally multiply the
# weight by the next equalization scale.
# Reshape the equalization scale so that we can multiply it to the weight along axis=1
equalization_scale_reshaped = reshape_scale(equalization_scale, 1, weight)
scaled_weight = torch.mul(weight, torch.reciprocal(equalization_scale_reshaped))
if next_equalization_scale is None:
setattr(modules[weight_parent_name], weight_name, scaled_weight)
return
# Multiply the weights row wise by the next equalization scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=0
next_equalization_scale_reshaped = reshape_scale(
next_equalization_scale, 0, scaled_weight
)
scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
setattr(modules[weight_parent_name], weight_name, scaled_weight)
assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
# Multiply the bias element wise by the next equalization scale
bias_node = None
for node in op_node.args:
# Find the node containing the bias values
if isinstance(node, Node) and node.op == "get_attr" and "bias" in node.name:
bias_node = node
break
if bias_node is None:
return
bias_parent_name, bias_name = _parent_name(bias_node.target)
bias = getattr(modules[bias_parent_name], bias_name)
# Reshape the equalization scale so that we can multiply it element-wise to the bias
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
scaled_bias = torch.mul(bias, next_equalization_scale_reshaped)
setattr(modules[bias_parent_name], bias_name, scaled_bias)
def clear_weight_quant_obs_node(op_node: Node, modules: Dict[str, nn.Module]) -> None:
"""Given the operation node, we want find the corresponding quantization
observer and reset its min/max values
"""
weight_eq_obs_node = maybe_get_weight_eq_obs_node(op_node, modules)
if weight_eq_obs_node is None:
return
weight_quant_obs_node = weight_eq_obs_node.args[0]
if weight_quant_obs_node is None:
return
assert isinstance(weight_quant_obs_node, Node)
weight_quant_obs = modules[str(weight_quant_obs_node.target)]
assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
weight_quant_obs.reset_min_max_vals() # type: ignore[operator]
def remove_node(model: GraphModule, node: Node, prev_node: Node):
"""Removes the given node from the model by replacing all of its users with
the given previous node
"""
# For all of the current node's users, replace the current node with
# the given previous node
orig_users = list(node.users.keys())
for user_node in orig_users:
user_node.replace_input_with(node, prev_node)
# Erase the InputEqualizationObserver node
model.graph.erase_node(node)
def update_obs_for_equalization(
model: GraphModule, modules: Dict[str, nn.Module]
) -> Dict[str, _WeightEqualizationObserver]:
"""Update all of the observer's equalization scale. For each
InputEqualizationObserver, we will find the location of the next
WeightEqualizationObserver, create it, and calculate the equalization scale
based on the two observers.
We will then return a dictionary mapping operation node names to
the corresponding WeightEqualizationObservers for that operation.
"""
weight_eq_obs_dict = {}
for node in model.graph.nodes:
if node.op == "call_module" and isinstance(
modules[node.target], _InputEqualizationObserver
):
input_eq_obs = modules[node.target]
assert isinstance(input_eq_obs, _InputEqualizationObserver)
op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
if op_node is None or weight_eq_obs is None:
continue
if op_node.op == "call_module":
# Calibrate the weight equalization observer since it has just
# been created
if fused_module_supports_equalization(modules[str(op_node.target)]):
module = modules[str(op_node.target)][0] # type: ignore[index]
assert nn_module_supports_equalization(module)
weight_eq_obs(module.weight)
else:
weight_eq_obs(modules[str(op_node.target)].weight)
# Calculate and set the equalization scale values
equalization_scale = calculate_equalization_scale(
input_eq_obs, weight_eq_obs
)
input_eq_obs.set_equalization_scale(equalization_scale)
weight_eq_obs.set_equalization_scale(equalization_scale)
weight_eq_obs_dict[op_node.name] = weight_eq_obs
return weight_eq_obs_dict
def convert_eq_obs(
model: GraphModule,
modules: Dict[str, nn.Module],
weight_eq_obs_dict: Dict[str, _WeightEqualizationObserver],
) -> None:
"""Converts the equalization operations and updates the other nodes in the
following way:
- Removes the input equalization observers and inserts a mul operator
along with an equalization scale node wherever applicable (we do not
want to insert a mul operator between connecting linear layers).
- Updates the input quantization observers with the scaled input min/max
values.
- Scales the weights by the current and next equalization scales.
- Removes the weight equalization observer node if it exists.
Before (after prepare):
                            weight values
                                  |
                            WeightQuantObs
                                  |
                             WeightEqObs
                                  |
x -> InpQuantObs -> InpEqObs -> linear -> OutQuantObs
After this function:
                                    scaled weight values
                                              |
equalization scale                     WeightQuantObs
      |                                       |
x -> mul -> InpQuantObs (scaled min/max) -> linear -> OutQuantObs
After convert:
equalization scale                scaled weight values
      |                                   |
x -> mul -> quantize_per_tensor -> quantized::linear
Note that although the equalization observer appears after the quantization
observer in the prepared graph, the mul node appears before the quantization
node after convert_fx. This is because placing the equalization observer after
the quantization observer in prepare_fx lets us keep the invariant that
inserting observers for the current node does not modify the part of the graph
that precedes it.
Having the equalization observer before the quantization observer would also
cause inconsistencies in the ordering of the quantization and equalization
observers.
For example, a single linear layer would look like:
x -> InpEqObs1 -> InpQuantObs1 -> linear1 -> OutQuantObs1
But between two connected linear layers, it would look like:
linear1 -> OutQuantObs1 -> InpEqObs2 -> linear2 -> OutQuantObs2
"""
for node in model.graph.nodes:
if node.op == "call_module" and isinstance(
modules[node.target], _InputEqualizationObserver
):
inp_quant_obs_node = node.args[0]
prev_node = inp_quant_obs_node.args[0]
# If the previous node is a layer that needs to be equalized, then
# we will remove the current node because we do not need to add any
# equalization nodes between two layers that need to be equalized
# Before: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> input_eq_obs2 (node) -> linear2
# After: linear1/relu (prev_node) -> output_quant_obs1 (inp_quant_obs_node) -> linear2
if (
node_supports_equalization(prev_node, modules)
or "relu" in prev_node.name
):
remove_node(model, node, inp_quant_obs_node)
continue
# Update the following input quantization observer's min/max values
scale_input_observer(node, modules)
# Remove the InputEqualization node and add a mul operator before
# the quantization observer node that appears before the equalization node
# Before: x -> input_quant_obs -> input_eq_obs -> linear
# After: x -> mul -> input_quant_obs -> linear
# Create a node containing the equalization scale
with model.graph.inserting_before(inp_quant_obs_node):
get_new_eq_scale_name = get_new_attr_name_with_prefix(
prev_node.name + "_equalization_scale"
)
name = get_new_eq_scale_name(modules)
setattr(model, name, modules[node.target].equalization_scale)
eq_scale_node = model.graph.create_node("get_attr", name)
# Create a node multiplying the input with the equalization scale
with model.graph.inserting_after(eq_scale_node):
inputs = (prev_node, eq_scale_node)
mul_node = model.graph.create_node("call_function", torch.mul, inputs)
# Set the mul node to be the input_quant_obs_node's input instead of
# the previous node
inp_quant_obs_node.replace_input_with(prev_node, mul_node)
remove_node(model, node, inp_quant_obs_node)
elif weight_eq_obs_dict.get(node.name, None) is not None:
weight_eq_obs = weight_eq_obs_dict.get(node.name)
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
equalization_scale = weight_eq_obs.equalization_scale
if (
equalization_scale.nelement() == 1
and equalization_scale == torch.tensor(1)
):
equalization_scale = None # type: ignore[assignment]
maybe_next_equalization_scale = maybe_get_next_equalization_scale(
node, modules
)
# Scale the weight nodes
if node.op == "call_module":
scale_weight_node(
node, modules, equalization_scale, maybe_next_equalization_scale
)
elif node.op == "call_function":
scale_weight_functional(
node,
model,
modules,
equalization_scale,
maybe_next_equalization_scale,
)
weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
if weight_eq_obs_node is None:
return
assert isinstance(
modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver
)
# Clear the quantization observer's min/max values so that they
# can get updated later based on the new scale values
clear_weight_quant_obs_node(node, modules)
# Erase the weight equalization observer node
prev_node = weight_eq_obs_node.args[0]
remove_node(model, weight_eq_obs_node, prev_node)
else:
raise ValueError(
"Expected operation node to be 'call_module' or 'call_function"
+ f"Instead got node {node.name} as '{node.op}'."
)
def _convert_equalization_ref(model: GraphModule):
"""Reference function which applies changes needed for equalization, but
does not quantize the nodes
"""
modules = dict(model.named_modules(remove_duplicate=False))
# Calculate the equalization scale, update the observers with the scaled
# inputs, and scale the weight
weight_eq_obs_dict = update_obs_for_equalization(model, modules)
convert_eq_obs(model, modules, weight_eq_obs_dict)
return GraphModule(model, model.graph)
###############################################################################
# Functions for running the equalized model on the Numeric Suite #
###############################################################################
def get_layer_sqnr_dict(
model_a: nn.Module, model_b: nn.Module, x: torch.Tensor
) -> Dict[str, float]:
"""Runs the Numeric Suite on model_a and model_b and returns a dictionary
containing the SQNR between layers in model_a and model_b.
Note: In order to support equalized models, this function has a hacky fix in
which we do not match any torch.mul operators. This is because equalized
models contain extra mul operators to scale the input by the equalization
scale, but this edge case has not been resolved yet within the numeric suite code.
Args:
model_a: A float model
model_b: A quantized model
x: Inputs to use during calibration
"""
import torch.ao.ns._numeric_suite_fx as ns
from torch.ao.ns.fx.mappings import get_unmatchable_types_map
unmatchable_types_map = get_unmatchable_types_map()
unmatchable_types_map["funs_unmatchable"].add(torch.mul)
model_a_ns, model_b_ns = ns.add_loggers(
"fp32",
model_a,
"int8",
model_b,
ns.OutputLogger,
unmatchable_types_map=unmatchable_types_map,
)
model_a_ns(x)
model_b_ns(x)
activation_comparison_dict = ns.extract_logger_info(
model_a_ns, model_b_ns, ns.OutputLogger, "int8"
)
ns.extend_logger_results_with_comparison(
activation_comparison_dict,
"fp32",
"int8",
torch.ao.ns.fx.utils.compute_sqnr,
"sqnr",
)
# Construct a dictionary mapping layer names to the SQNR values
layer_sqnr_dict = {}
for key in activation_comparison_dict:
layer = activation_comparison_dict[key]["node_output"]["int8"][0]["fqn"]
sqnr = activation_comparison_dict[key]["node_output"]["int8"][0]["sqnr"][0]
layer_sqnr_dict[layer] = sqnr
return layer_sqnr_dict
def get_equalization_qconfig_dict(
layer_sqnr_dict: Dict[str, float], num_layers_to_equalize: int
) -> Any:
"""Given the layer to SQNR dictionary, find the layers with the highest
quantization errors, and return an equalization_qconfig_dict
specifying to only equalize those top layers.
Args:
layer_sqnr_dict: Dictionary mapping layer names to SQNR values (found
when comparing an equalized model against a float model)
num_layers_to_equalize: Number of layers with the highest quantization
errors to equalize
"""
# Sort the layer_sqnr_dict by SQNR value and get the layers with the lowest
# SQNR values (i.e. the highest quantization errors)
layer_sqnr_sorted = sorted(layer_sqnr_dict.items(), key=operator.itemgetter(1))
layers_to_equalize = layer_sqnr_sorted[:num_layers_to_equalize]
# Constructs an equalization_qconfig_dict that specifies to only equalize
# the layers with the highest quantization errors
module_to_qconfig_list = [
(item[0], default_equalization_qconfig) for item in layers_to_equalize
]
equalization_qconfig_dict = {"module_name": module_to_qconfig_list}
return equalization_qconfig_dict
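# Illustrative end-to-end sketch (``float_model``, ``equalized_quant_model``,
# and ``calib_x`` are placeholders): compare an equalized-and-quantized model
# against its float counterpart, then re-equalize only the worst layers:
#
#     sqnr_dict = get_layer_sqnr_dict(float_model, equalized_quant_model, calib_x)
#     eq_qconfig_dict = get_equalization_qconfig_dict(sqnr_dict, num_layers_to_equalize=3)
#     # eq_qconfig_dict == {"module_name": [(<layer fqn>, default_equalization_qconfig), ...]}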