# Decomposed quantization operators registered under the "quantized_decomposed" library.
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math
from typing import Optional, Tuple
import torch
from torch._refs import _unsqueeze_multiple
from torch.ao.quantization.utils import determine_qparams, validate_qmin_qmax
from torch.library import impl, Library
# Note: "decomposed" is short for "decomposed quantized tensor"; the shorter form
# is used so that the library name is not too long.
# The library is created with kind "DEF"; operator schemas are declared on it with
# `quantized_decomposed_lib.define(...)` and kernels are attached with `@impl(...)`.
quantized_decomposed_lib = Library("quantized_decomposed", "DEF")
_INTEGER_DTYPES = [torch.uint8, torch.int8, torch.uint16, torch.int16, torch.int32]
_FLOAT_DTYPES = [torch.float8_e5m2, torch.float8_e4m3fn]
_DTYPE_TO_QVALUE_BOUNDS = {
k: (torch.iinfo(k).min, torch.iinfo(k).max) for k in _INTEGER_DTYPES
}
_DTYPE_TO_QVALUE_BOUNDS.update(
{k: (int(torch.finfo(k).min), int(torch.finfo(k).max)) for k in _FLOAT_DTYPES}
)
# Helper to check the passed in quant min and max are valid for the dtype
def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
    """Check that ``quant_min``/``quant_max`` fit in the range recorded for ``dtype``.

    Args:
        quant_min: requested lower bound of the quantized range
        quant_max: requested upper bound of the quantized range
        dtype: dtype whose bounds are looked up in ``_DTYPE_TO_QVALUE_BOUNDS``

    Raises:
        ValueError: ``dtype`` has no entry in ``_DTYPE_TO_QVALUE_BOUNDS``
        AssertionError: ``quant_min`` is below the dtype's minimum or ``quant_max``
            is above the dtype's maximum
    """
    bounds = _DTYPE_TO_QVALUE_BOUNDS.get(dtype)
    if bounds is None:
        raise ValueError(f"Unsupported dtype: {dtype}")

    lowest, highest = bounds
    assert (
        quant_min >= lowest
    ), f"quant_min out of bound for dtype, quant_min_lower_bound: {lowest} quant_min: {quant_min}"
    assert (
        quant_max <= highest
    ), f"quant_max out of bound for dtype, quant_max_upper_bound: {highest} quant_max: {quant_max}"
quantized_decomposed_lib.define(
    "quantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_tensor", "CompositeExplicitAutograd")
def quantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Per tensor affine quantization: one (scale, zero_point) pair maps every
    floating point element of the Tensor to a quantized value

    Args:
       input (torch.Tensor): float32 Tensor; float16 and bfloat16 Tensors are
         accepted and converted to float32 first
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8). The quantization parameters
       are not stored in the Tensor; they only live in the function arguments.
    """
    # Half precision inputs are upcast so the arithmetic below runs in float32.
    if input.dtype in (torch.float16, torch.bfloat16):
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

    # Multiply by the reciprocal of the scale, round, shift by the zero point,
    # then saturate to [quant_min, quant_max] before the final dtype cast.
    inv_scale = 1.0 / scale
    shifted = torch.round(input * inv_scale) + zero_point
    saturated = torch.clamp(shifted, quant_min, quant_max)
    return saturated.to(dtype)
@impl(quantized_decomposed_lib, "quantize_per_tensor", "Meta")
def quantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Meta kernel for `quantize_per_tensor`: checks the input dtype and returns
    an empty Tensor laid out like `input` with the requested `dtype`
    """
    half_precision = (torch.float16, torch.bfloat16)
    if input.dtype in half_precision:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    result = torch.empty_like(input, dtype=dtype)
    return result
quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(
    quantized_decomposed_lib, "quantize_per_tensor.tensor", "CompositeExplicitAutograd"
)
def quantize_per_tensor_tensor(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Per tensor affine quantization with Tensor quantization parameters

    Same as `quantize_per_tensor`, except that `scale` and `zero_point` are passed
    as one-element Tensors; their values are extracted with `.item()` and forwarded.
    """
    zero_point_count = zero_point.numel()
    assert (
        zero_point_count == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point_count}"
    scale_count = scale.numel()
    assert (
        scale_count == 1
    ), f"Expecting scale tensor to be one element, but received : {scale_count}"

    scale_value = scale.item()
    zero_point_value = zero_point.item()
    return quantize_per_tensor(
        input, scale_value, zero_point_value, quant_min, quant_max, dtype
    )
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor", "Meta")
def quantize_per_tensor_tensor_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Meta kernel for `quantize_per_tensor.tensor`: checks that `scale` and
    `zero_point` hold one element each and that the input is (or converts to)
    float32, then returns an empty Tensor like `input` with the requested `dtype`.
    `quant_min` and `quant_max` are not read here.
    """
    if input.dtype in (torch.float16, torch.bfloat16):
        input = input.to(torch.float32)

    zero_point_count = zero_point.numel()
    assert (
        zero_point_count == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point_count}"
    scale_count = scale.numel()
    assert (
        scale_count == 1
    ), f"Expecting scale tensor to be one element, but received : {scale_count}"
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"

    result = torch.empty_like(input, dtype=dtype)
    return result
# TODO: remove other variants and keep this one
quantized_decomposed_lib.define(
    "quantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
    "Tensor quant_min, Tensor quant_max, ScalarType dtype) -> Tensor"
)


@impl(
    quantized_decomposed_lib, "quantize_per_tensor.tensor2", "CompositeExplicitAutograd"
)
def quantize_per_tensor_tensor2(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Per tensor affine quantization where all quantization parameters are Tensors

    Same as `quantize_per_tensor`, except that `scale`, `zero_point`, `quant_min`
    and `quant_max` are passed as Tensors; each value is extracted with `.item()`
    and forwarded. Only `scale` and `zero_point` are explicitly checked to hold a
    single element.
    """
    zero_point_count = zero_point.numel()
    assert (
        zero_point_count == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point_count}"
    scale_count = scale.numel()
    assert (
        scale_count == 1
    ), f"Expecting scale tensor to be one element, but received : {scale_count}"

    scalar_args = (
        scale.item(),
        zero_point.item(),
        quant_min.item(),
        quant_max.item(),
    )
    return quantize_per_tensor(input, *scalar_args, dtype)
@impl(quantized_decomposed_lib, "quantize_per_tensor.tensor2", "Meta")
def quantize_per_tensor_tensor2_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Meta kernel for `quantize_per_tensor.tensor2`; forwards every argument
    unchanged to `quantize_per_tensor_tensor_meta`
    """
    result = quantize_per_tensor_tensor_meta(
        input, scale, zero_point, quant_min, quant_max, dtype
    )
    return result
# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we found there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_tensor(Tensor input, float scale, int zero_point, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_tensor", "CompositeExplicitAutograd")
def dequantize_per_tensor(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Per tensor affine dequantization: one (scale, zero_point) pair maps every
    quantized element of the Tensor back to a floating point value

    Args:
       input (torch.Tensor): Tensor whose dtype must equal the `dtype` argument
         (e.g. `torch.uint8`); together with scale/zero_point it represents a
         per tensor quantized Tensor
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
       quant_min (int): minimum quantized value for input Tensor (not used in
         computation, reserved for pattern matching)
       quant_max (int): maximum quantized value for input Tensor (not used in
         computation, reserved for pattern matching)
       dtype (torch.dtype): dtype of the input Tensor; checked against
         `input.dtype` and against the supported dtypes
       out_dtype (torch.dtype?): optional dtype the input is converted to before
         the arithmetic; torch.float32 when None

    Returns:
       dequantized Tensor computed as `(input.to(out_dtype) - zero_point) * scale`

    Raises:
       ValueError: `dtype` is not a key of `_DTYPE_TO_QVALUE_BOUNDS`
    """
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
    out_dtype = torch.float32 if out_dtype is None else out_dtype

    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")

    # TODO: investigate why
    # (input - zero_point).to(torch.float32) * scale
    # failed the test
    centered = input.to(out_dtype) - zero_point
    return centered * scale
@impl(quantized_decomposed_lib, "dequantize_per_tensor", "Meta")
def dequantize_per_tensor_meta(
    input: torch.Tensor,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Meta kernel for `dequantize_per_tensor`

    The `scale`/`zero_point` annotations follow the operator schema
    (`float scale, int zero_point`); none of the quantization parameters are read.

    Args:
       input (torch.Tensor): quantized input Tensor; only its layout is used
       out_dtype (torch.dtype?): dtype of the result, torch.float32 when None

    Returns:
       empty Tensor laid out like `input` with dtype `out_dtype`
    """
    if out_dtype is None:
        out_dtype = torch.float32
    return torch.empty_like(input, dtype=out_dtype)
quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor(Tensor input, Tensor scale, Tensor zero_point, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(
    quantized_decomposed_lib,
    "dequantize_per_tensor.tensor",
    "CompositeExplicitAutograd",
)
def dequantize_per_tensor_tensor(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Per tensor affine dequantization with Tensor quantization parameters

    Same as `dequantize_per_tensor`, except that `scale` and `zero_point` are
    passed as one-element Tensors; their values are extracted with `.item()`
    and forwarded.
    """
    zero_point_count = zero_point.numel()
    assert (
        zero_point_count == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point_count}"
    scale_count = scale.numel()
    assert (
        scale_count == 1
    ), f"Expecting scale tensor to be one element, but received : {scale_count}"

    scale_value = scale.item()
    zero_point_value = zero_point.item()
    return dequantize_per_tensor(
        input,
        scale_value,
        zero_point_value,
        quant_min,
        quant_max,
        dtype,
        out_dtype=out_dtype,
    )
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor", "Meta")
def dequantize_per_tensor_tensor_meta(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Meta kernel for `dequantize_per_tensor.tensor`: checks that `scale` and
    `zero_point` hold one element each, that `input.dtype` equals `dtype` and that
    `dtype` is supported, then returns an empty Tensor like `input` with dtype
    `out_dtype` (torch.float32 when None)

    Raises:
       ValueError: `dtype` is not a key of `_DTYPE_TO_QVALUE_BOUNDS`
    """
    out_dtype = torch.float32 if out_dtype is None else out_dtype

    zero_point_count = zero_point.numel()
    assert (
        zero_point_count == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point_count}"
    scale_count = scale.numel()
    assert (
        scale_count == 1
    ), f"Expecting scale tensor to be one element, but received : {scale_count}"
    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"

    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
        raise ValueError(f"Unsupported dtype in dequantize_per_tensor: {dtype}")
    return torch.empty_like(input, dtype=out_dtype)
# TODO: remove other variants and keep this one
quantized_decomposed_lib.define(
    "dequantize_per_tensor.tensor2(Tensor input, Tensor scale, Tensor zero_point, "
    "Tensor quant_min, Tensor quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(
    quantized_decomposed_lib,
    "dequantize_per_tensor.tensor2",
    "CompositeExplicitAutograd",
)
def dequantize_per_tensor_tensor2(
    input: torch.Tensor,
    scale: torch.Tensor,
    zero_point: torch.Tensor,
    quant_min: torch.Tensor,
    quant_max: torch.Tensor,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Per tensor affine dequantization where all quantization parameters are Tensors

    Same as `dequantize_per_tensor`, except that `scale`, `zero_point`, `quant_min`
    and `quant_max` are passed as Tensors; each value is extracted with `.item()`
    and forwarded. Only `scale` and `zero_point` are explicitly checked to hold a
    single element.
    """
    zero_point_count = zero_point.numel()
    assert (
        zero_point_count == 1
    ), f"Expecting zero_point tensor to be one element, but received : {zero_point_count}"
    scale_count = scale.numel()
    assert (
        scale_count == 1
    ), f"Expecting scale tensor to be one element, but received : {scale_count}"

    scalar_args = (
        scale.item(),
        zero_point.item(),
        quant_min.item(),
        quant_max.item(),
    )
    return dequantize_per_tensor(input, *scalar_args, dtype, out_dtype=out_dtype)
@impl(quantized_decomposed_lib, "dequantize_per_tensor.tensor2", "Meta")
def dequantize_per_tensor_tensor2_meta(
    input,
    scale,
    zero_point,
    quant_min,
    quant_max,
    dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Meta kernel for `dequantize_per_tensor.tensor2`; forwards every argument
    unchanged to `dequantize_per_tensor_tensor_meta`
    """
    result = dequantize_per_tensor_tensor_meta(
        input, scale, zero_point, quant_min, quant_max, dtype, out_dtype=out_dtype
    )
    return result
quantized_decomposed_lib.define(
    "choose_qparams.tensor(Tensor input, int quant_min, int quant_max, "
    "float eps, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(quantized_decomposed_lib, "choose_qparams.tensor", "CompositeExplicitAutograd")
def choose_qparams_tensor(
    input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Derive per tensor affine quantization parameters (scale and zero_point)
    from the minimum and maximum value of the input Tensor

    Args:
       input (torch.Tensor): floating point input Tensor (float32, float16 or bfloat16)
       qmin (int): minimum quantized value for target quantized Tensor
       qmax (int): maximum quantized value for target quantized Tensor
       eps (float): wrapped in a one-element Tensor and passed to `determine_qparams`
       dtype (torch.dtype): dtype for target quantized Tensor

    Returns:
       the (scale, zero_point) pair produced by `determine_qparams`
    """
    accepted_input_dtypes = [torch.float32, torch.float16, torch.bfloat16]
    assert (
        input.dtype in accepted_input_dtypes
    ), f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert (
        dtype in _DTYPE_TO_QVALUE_BOUNDS
    ), f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    observed_min, observed_max = torch.aminmax(input)
    eps_tensor = torch.Tensor([eps])
    return determine_qparams(
        observed_min,
        observed_max,
        qmin,
        qmax,
        dtype,
        eps_tensor,
        has_customized_qrange=False,
    )
quantized_decomposed_lib.define(
    "choose_qparams_symmetric.tensor(Tensor input, int quant_min, int quant_max, "
    "float eps, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_symmetric.tensor",
    "CompositeExplicitAutograd",
)
def choose_qparams_symmetric_tensor(
    input: torch.Tensor, qmin: int, qmax: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Derive per tensor quantization parameters (scale and zero_point) from the
    minimum and maximum value of the input Tensor, using the
    `torch.per_tensor_symmetric` qscheme

    Args:
       input (torch.Tensor): floating point input Tensor (float32, float16 or bfloat16)
       qmin (int): minimum quantized value for target quantized Tensor
       qmax (int): maximum quantized value for target quantized Tensor
       eps (float): wrapped in a one-element Tensor and passed to `determine_qparams`
       dtype (torch.dtype): dtype for target quantized Tensor

    Returns:
       the (scale, zero_point) pair produced by `determine_qparams`
    """
    accepted_input_dtypes = [torch.float32, torch.float16, torch.bfloat16]
    assert (
        input.dtype in accepted_input_dtypes
    ), f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert (
        dtype in _DTYPE_TO_QVALUE_BOUNDS
    ), f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
    validate_qmin_qmax(qmin, qmax)

    observed_min, observed_max = torch.aminmax(input)
    eps_tensor = torch.Tensor([eps])
    return determine_qparams(
        observed_min,
        observed_max,
        qmin,
        qmax,
        dtype,
        eps_tensor,
        has_customized_qrange=False,
        qscheme=torch.per_tensor_symmetric,
    )
@impl(quantized_decomposed_lib, "choose_qparams.tensor", "Meta")
def choose_qparams_tensor_meta(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Meta kernel for `choose_qparams.tensor`: checks the input dtype and that
    quant_min < quant_max, then returns two empty one-element Tensors on the
    input's device (a double Tensor for scale, an int64 Tensor for zero_point)
    """
    accepted_input_dtypes = [torch.float32, torch.float16, torch.bfloat16]
    assert (
        input.dtype in accepted_input_dtypes
    ), f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
    assert quant_min < quant_max, (
        "Expecting quant_min to be smaller than quant_max but received min: "
        f"{quant_min} max: {quant_max}"
    )

    scale = torch.empty(1, dtype=torch.double, device=input.device)
    zero_point = torch.empty(1, dtype=torch.int64, device=input.device)
    return scale, zero_point
@impl(quantized_decomposed_lib, "choose_qparams_symmetric.tensor", "Meta")
def choose_qparams_symmetric_tensor_meta(
    input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Meta kernel for `choose_qparams_symmetric.tensor`: performs no validation
    and returns two empty one-element Tensors on the input's device (a double
    Tensor for scale, an int64 Tensor for zero_point)
    """
    scale = torch.empty(1, dtype=torch.double, device=input.device)
    zero_point = torch.empty(1, dtype=torch.int64, device=input.device)
    return scale, zero_point
# Helper function used to implement per-channel quantization against any axis
def _permute_to_axis_zero(x, axis):
    """Swap dimension `axis` of `x` with dimension 0

    Returns:
       the permuted Tensor and the permutation list that was applied. Because
       the permutation only swaps two dimensions, applying the same list again
       restores the original dimension order.
    """
    swap_order = list(range(x.dim()))
    swap_order[axis], swap_order[0] = 0, axis
    return x.permute(tuple(swap_order)), swap_order
quantized_decomposed_lib.define(
    "quantize_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_channel", "CompositeExplicitAutograd")
def quantize_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Per channel affine quantization: every slice of the Tensor along `axis`
    is mapped from floating point to quantized values with its own
    (scale, zero_point) pair

    Args:
       input (torch.Tensor): float32 Tensor; float16 and bfloat16 Tensors are
         accepted and converted to float32 first
       scales (torch.Tensor): a list of scale quantization parameter for
         affine quantization, one per channel
       zero_points (torch.Tensor): a list of zero_point quantization parameter for
         affine quantization, one per channel
       axis (int): dimension of `input` that indexes the channels
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8). The quantization parameters
       are not stored in the Tensor; they only live in the function arguments.
    """
    if input.dtype in (torch.float16, torch.bfloat16):
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

    # Move the channel axis to the front so the per channel parameters, viewed
    # as (num_channels, 1, ..., 1), broadcast against the input.
    input, permute_axis_list = _permute_to_axis_zero(input, axis)
    broadcast_shape = [1] * input.dim()
    broadcast_shape[0] = scales.shape[0]
    scales = scales.view(broadcast_shape)
    zero_points = zero_points.view(broadcast_shape)

    shifted = torch.round(input * (1.0 / scales)) + zero_points
    saturated = torch.clamp(shifted, quant_min, quant_max)
    # Undo the axis swap before casting to the requested dtype.
    restored = saturated.permute(tuple(permute_axis_list))
    return restored.to(dtype)
@impl(quantized_decomposed_lib, "quantize_per_channel", "Meta")
def quantize_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Meta kernel for `quantize_per_channel`: runs the same input dtype, axis and
    quant_min/quant_max checks as the real kernel and returns an empty Tensor
    laid out like `input` with the requested `dtype`
    """
    half_precision = (torch.float16, torch.bfloat16)
    if input.dtype in half_precision:
        input = input.to(torch.float32)
    assert (
        input.dtype == torch.float32
    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    result = torch.empty_like(input, dtype=dtype)
    return result
# Note: quant_min/quant_max/dtype are not used in the operator, but for now it's kept in
# the signature as metadata for the input Tensor, this might be useful for pattern
# matching in the future
# We will revisit this later if we found there are no use cases for it
quantized_decomposed_lib.define(
    "dequantize_per_channel(Tensor input, Tensor scales, Tensor? zero_points, int axis, "
    "int quant_min, int quant_max, ScalarType dtype, *, ScalarType? out_dtype=None) -> Tensor"
)


@impl(quantized_decomposed_lib, "dequantize_per_channel", "CompositeExplicitAutograd")
def dequantize_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Per channel affine dequantization: every slice of the Tensor along `axis`
    is mapped from quantized values back to floating point with its own
    (scale, zero_point) pair

    Args:
       input (torch.Tensor): Tensor whose dtype must equal the `dtype` argument
         (e.g. `torch.uint8`); together with scales/zero_points/axis it
         represents a per channel quantized Tensor
       scales (torch.Tensor): a list of scale quantization parameter for
         affine quantization, one per channel
       zero_points (torch.Tensor?): a list of zero_point quantization parameter for
         affine quantization, one per channel; when None nothing is subtracted
       axis (int): dimension of `input` that indexes the channels
       quant_min (int): minimum quantized value for input Tensor (only validated
         against `dtype`, not used in the arithmetic)
       quant_max (int): maximum quantized value for input Tensor (only validated
         against `dtype`, not used in the arithmetic)
       dtype (torch.dtype): dtype of the input Tensor
       out_dtype (torch.dtype?): optional dtype for output Tensor, torch.float32
         when None

    Returns:
       dequantized Tensor with dtype `out_dtype`
    """
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    out_dtype = torch.float32 if out_dtype is None else out_dtype
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)

    # Move the channel axis to the front so the per channel parameters, viewed
    # as (num_channels, 1, ..., 1), broadcast against the input.
    input, permute_axis_list = _permute_to_axis_zero(input, axis)
    broadcast_shape = [1] * input.dim()
    broadcast_shape[0] = scales.shape[0]
    scales = scales.view(broadcast_shape)

    if zero_points is None:
        res = input * scales
    else:
        res = (input - zero_points.view(broadcast_shape)) * scales
    res = res.to(out_dtype)

    # Undo the axis swap so the output has the caller's dimension order.
    return res.permute(tuple(permute_axis_list))
@impl(quantized_decomposed_lib, "dequantize_per_channel", "Meta")
def dequantize_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Meta kernel for `dequantize_per_channel`: runs the same input dtype, axis
    and quant_min/quant_max checks as the real kernel and returns an empty Tensor
    laid out like `input` with dtype `out_dtype` (torch.float32 when None)
    """
    assert (
        input.dtype == dtype
    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
    out_dtype = torch.float32 if out_dtype is None else out_dtype
    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    result = torch.empty_like(input, dtype=out_dtype)
    return result
quantized_decomposed_lib.define(
    "choose_qparams_per_token(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token",
    "CompositeExplicitAutograd",
)
def choose_qparams_per_token(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose symmetric quantization parameters for per token quantization. For a
    Tensor of shape (M1, M2, ...Mn, N), one scale/zero_point is computed for every
    group of N elements along the last dimension.

    Args:
       input (torch.Tensor): original floating point Tensor
       dtype (torch.dtype): target quantized dtype; only torch.int8 is supported

    Returns:
       scales and zero_points, both of shape (M1, M2, ...Mn, 1). Scales computed
       from a float16 input are converted to float32; zero_points are all zeros
       with the same dtype as scales.

    Raises:
       Exception: `dtype` is not torch.int8
    """
    # Largest magnitude in each token; the reduced dimension is kept as size 1.
    abs_max = torch.amax(torch.abs(input), dim=-1, keepdim=True)
    if abs_max.dtype == torch.float16:
        # want float scales to avoid overflows for fp16, (bf16 has wide enough range)
        abs_max = abs_max.float()

    if dtype != torch.int8:
        raise Exception(  # noqa: TRY002
            f"unsupported dtype in choose_qparams_per_token: {dtype}"
        )
    n_bits = 8
    quant_max = 2 ** (n_bits - 1) - 1

    # The lower clamp keeps an all-zero token from producing a zero scale.
    scales = abs_max.clamp(min=1e-5).div(quant_max)
    zero_points = torch.zeros_like(scales)
    return scales, zero_points
@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token",
    "Meta",
)
def choose_qparams_per_token_meta(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Meta kernel for `choose_qparams_per_token`: returns two empty Tensors on the
    input's device whose shape is the input shape with the last dimension
    replaced by 1 (a double Tensor for scales, an int64 Tensor for zero_points)
    """
    # NOTE(review): the real kernel appears to return scales/zero_points in the
    # input's floating point dtype rather than double/int64 - confirm whether
    # this mismatch is intended.
    size = [*input.shape[:-1], 1]
    scales = torch.empty(size, dtype=torch.double, device=input.device)
    zero_points = torch.empty(size, dtype=torch.int64, device=input.device)
    return scales, zero_points
quantized_decomposed_lib.define(
    "_choose_qparams_per_token_asymmetric_impl(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "_choose_qparams_per_token_asymmetric_impl",
    "CompositeImplicitAutograd",
)
def _choose_qparams_per_token_asymmetric_impl(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose asymmetric quantization parameters for per token quantization. For a
    Tensor of shape (M1, M2, ...Mn, N), one scale/zero_point is computed for every
    group of N elements along the last dimension.

    Args:
       input (torch.Tensor): original floating point Tensor
       dtype (torch.dtype): not read by this implementation; the quantized range
         is fixed to [-128, 127]

    Returns:
       scales (float64 Tensor) and zero_points (int64 Tensor), both of shape
       (M1, M2, ...Mn, 1)
    """
    # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
    qmin, qmax = -128, 127
    token_min = torch.amin(input, dim=-1, keepdim=True)
    token_max = torch.amax(input, dim=-1, keepdim=True)
    # Extend each token's range so that it always contains zero.
    range_lo = torch.min(token_min, torch.zeros_like(token_min))
    range_hi = torch.max(token_max, torch.zeros_like(token_max))
    eps = torch.finfo(torch.float32).eps  # use xnnpack eps?

    # scale
    q_range = float(qmax - qmin)
    scale = ((range_hi - range_lo) / q_range).clamp(min=eps)

    # zero point
    descaled_lo = range_lo / scale
    descaled_hi = range_hi / scale
    error_from_min = qmin + descaled_lo
    error_from_max = qmax + descaled_hi
    zero_point = torch.where(
        error_from_min + error_from_max > 0,
        qmin - descaled_lo,
        qmax - descaled_hi,
    )
    zero_point = torch.clamp(zero_point, qmin, qmax).round()

    return scale.to(torch.float64), zero_point.to(torch.int64)
quantized_decomposed_lib.define(
    "choose_qparams_per_token_asymmetric(Tensor input, ScalarType dtype) -> (Tensor, Tensor)"
)


@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token_asymmetric",
    "CompositeExplicitAutograd",
)
def choose_qparams_per_token_asymmetric(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Choose asymmetric per token quantization parameters by delegating to
    `_choose_qparams_per_token_asymmetric_impl` with the same arguments
    """
    scales, zero_points = _choose_qparams_per_token_asymmetric_impl(input, dtype)
    return scales, zero_points
@impl(
    quantized_decomposed_lib,
    "choose_qparams_per_token_asymmetric",
    "Meta",
)
def choose_qparams_per_token_asymmetric_meta(
    input: torch.Tensor,
    dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Meta kernel for `choose_qparams_per_token_asymmetric`: returns two empty
    Tensors on the input's device whose shape is the input shape with the last
    dimension replaced by 1 (a double Tensor for scales, an int64 Tensor for
    zero_points)
    """
    size = [*input.shape[:-1], 1]
    scales = torch.empty(size, dtype=torch.double, device=input.device)
    zero_points = torch.empty(size, dtype=torch.int64, device=input.device)
    return scales, zero_points
def _per_token_quant_qparam_dim_check(input, scales, zero_points):
    """Assert that `scales` and `zero_points` each hold exactly one element per
    token, where the number of tokens is the product of every dimension of
    `input` except the last one

    Raises:
       AssertionError: the element count of `scales` or `zero_points` differs
         from the number of tokens
    """
    num_tokens = math.prod(input.shape[:-1])
    assert (
        scales.numel() == num_tokens
    ), f"num_tokens: {num_tokens} scales: {scales.size()}"
    assert (
        zero_points.numel() == num_tokens
    ), f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
quantized_decomposed_lib.define(
    "quantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
    "int quant_min, int quant_max, ScalarType dtype) -> Tensor"
)


@impl(quantized_decomposed_lib, "quantize_per_token", "CompositeExplicitAutograd")
def quantize_per_token(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
):
    """Per token quantization: for a Tensor of shape (M1, M2, ...Mn, N), every
    group of N elements along the last dimension is mapped from floating point to
    quantized values with its own scale/zero_point. `scales` and `zero_points`
    must each hold M1 * M2 ... * Mn elements and broadcast against `input`.

    Args:
       input (torch.Tensor): original floating point Tensor
       scales (torch.Tensor): quantization parameter for per token affine quantization
       zero_points (torch.Tensor): quantization parameter for per token affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor

    Returns:
       Tensor with requested dtype (e.g. torch.uint8). The quantization parameters
       are not stored in the Tensor; they only live in the function arguments.
    """
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    _per_token_quant_qparam_dim_check(input, scales, zero_points)

    # Scale, shift by the zero point, round, saturate, then cast - in that order.
    shifted = input.mul(1.0 / scales).add(zero_points)
    saturated = shifted.round().clamp(quant_min, quant_max)
    return saturated.to(dtype)
@impl(quantized_decomposed_lib, "quantize_per_token", "Meta")
def quantize_per_token_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
):
    """Meta kernel for `quantize_per_token`.

    Runs `_quant_min_max_bounds_check` on the quantization range and then
    returns an uninitialized Tensor shaped like `input` with the requested
    `dtype`. `scales` and `zero_points` are not inspected here.
    """
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    placeholder = torch.empty_like(input, dtype=dtype)
    return placeholder
# Declare the schema of the `dequantize_per_token` op on quantized_decomposed_lib:
# (input, scales, zero_points, quant_min, quant_max, dtype, output_dtype) -> Tensor.
# The schema string is split over two adjacent literals only for line length.
quantized_decomposed_lib.define(
"dequantize_per_token(Tensor input, Tensor scales, Tensor zero_points, "
"int quant_min, int quant_max, ScalarType dtype, ScalarType output_dtype) -> Tensor"
)
@impl(quantized_decomposed_lib, "dequantize_per_token", "CompositeExplicitAutograd")
def dequantize_per_token(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    output_dtype: torch.dtype = torch.float32,
):
    """Affine-dequantize a Tensor with one (scale, zero_point) pair per token.

    Maps quantized values back to floating point as
    `(input - zero_points) * scales`, where each run of N elements along the
    last dimension of an (M1, M2, ..., Mn, N) input shares one quantization
    parameter. `scales` and `zero_points` must broadcast against `input`.

    `quant_min`, `quant_max` and `dtype` are part of the op signature but are
    not used by this kernel.

    Args:
       input (torch.Tensor): quantized Tensor (uint8, int8 etc.)
       scales (torch.Tensor): per-token scale
       zero_points (torch.Tensor): per-token zero point
       quant_min (int): minimum quantized value for input Tensor
       quant_max (int): maximum quantized value for input Tensor
       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
       output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor

    Returns:
       dequantized Tensor with dtype `output_dtype`
    """
    centered = input - zero_points
    rescaled = centered * scales
    # The product takes the promoted dtype of its operands (the scales may be
    # float64), so cast explicitly to the dtype the caller asked for.
    return rescaled.to(output_dtype)
@impl(quantized_decomposed_lib, "dequantize_per_token", "Meta")
def dequantize_per_token_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    output_dtype: torch.dtype = torch.float32,
):
    """Meta kernel for `dequantize_per_token`.

    Runs `_quant_min_max_bounds_check` on the quantization range and then
    returns an uninitialized Tensor shaped like `input` with dtype
    `output_dtype`. `scales` and `zero_points` are not inspected here.
    """
    _quant_min_max_bounds_check(quant_min, quant_max, dtype)
    # TODO: support fp16
    placeholder = torch.empty_like(input, dtype=output_dtype)
    return placeholder
# Declare the schema of the `quantize_per_channel_group` op on
# quantized_decomposed_lib:
# (input, scales, zero_points, quant_min, quant_max, dtype, group_size) -> Tensor.
# Note the schema gives `group_size` no default; the Python kernels default it to 128.
quantized_decomposed_lib.define(
"quantize_per_channel_group(Tensor input, Tensor scales, Tensor zero_points, int quant_min, "
"int quant_max, ScalarType dtype, int group_size) -> Tensor"
)
# TODO: dtype is not validated against quant_min/quant_max for now; it is only
# used as the dtype the result is cast to.
@impl(quantized_decomposed_lib, "quantize_per_channel_group", "CompositeExplicitAutograd")
def quantize_per_channel_group(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size=128,
):
    """Groupwise affine quantization of a 2-d Tensor.

    The input is viewed as rows of `group_size` consecutive elements; each
    row is quantized as `clamp(round(x * (1 / scale) + zero_point))` with its
    own scale and zero point, and the result is cast to `dtype` and reshaped
    back to the shape of `input`.

    Args:
       input (torch.Tensor): 2-d floating point Tensor without NaNs
       scales (torch.Tensor): per-group scale; flattened to a column internally
       zero_points (torch.Tensor): per-group zero point; flattened to a column
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype for output Tensor
       group_size (int): elements per group; must be > 1 and divide the last
           dimension of `input`

    Returns:
       Tensor of dtype `dtype` with the same shape as `input`

    Raises:
       AssertionError: on an invalid `group_size`, a non 2-d input, or NaNs
    """
    assert group_size > 1
    row_width = input.shape[-1]
    # needed for GPTQ single column quantize
    if group_size > row_width and scales.shape[-1] == 1:
        group_size = row_width

    assert row_width % group_size == 0
    assert input.dim() == 2

    # TODO: check for dtype, currently we can't express torch.int4 so it's omitted
    grouped = input.reshape(-1, group_size)
    assert torch.isnan(grouped).sum() == 0

    # One scale / zero point per row of `grouped`.
    group_scales = scales.reshape(-1, 1)
    group_zero_points = zero_points.reshape(-1, 1)

    rounded = grouped.mul(1.0 / group_scales).add(group_zero_points).round()
    # `rounded` is a fresh temporary produced above, so clamping it in place
    # does not touch the caller's tensors.
    saturated = rounded.clamp_(quant_min, quant_max)
    return saturated.to(dtype).reshape_as(input)
@impl(quantized_decomposed_lib, "quantize_per_channel_group", "Meta")
def quantize_per_channel_group_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size=128,
):
    """Meta kernel for groupwise quantization within each channel of a 2-d Tensor.

    The op being modelled splits each row of an (M, N) Tensor into groups of
    `group_size` elements and quantizes every group with its own scale and
    zero point, so scales/zero_points carry one entry per group. This meta
    kernel only repeats the shape checks and produces an uninitialized output.

    Args:
       input (torch.Tensor): original floating point Tensor (2-d)
       scales (torch.Tensor): quantization parameter for per channel group affine quantization
       zero_points (torch.Tensor): quantization parameter for per channel group affine quantization
       quant_min (int): minimum quantized value for output Tensor
       quant_max (int): maximum quantized value for output Tensor
       dtype (torch.dtype): requested dtype (e.g. torch.uint8) for output Tensor
       group_size (int): elements per group; must be > 1 and divide the last
           dimension of `input`

    Returns:
       Uninitialized Tensor shaped like `input` with the requested dtype; the
       quantization parameters are not stored in the Tensor, they stay in the
       function arguments.

    Raises:
       AssertionError: on an invalid `group_size` or a non 2-d input
    """
    assert group_size > 1
    row_width = input.shape[-1]
    # needed for GPTQ single column quantize
    if group_size > row_width and scales.shape[-1] == 1:
        group_size = row_width

    assert row_width % group_size == 0
    assert input.dim() == 2
    placeholder = torch.empty_like(input, dtype=dtype)
    return placeholder
# Declare the schema of the `dequantize_per_channel_group` op on
# quantized_decomposed_lib. `zero_points` is declared as `Tensor?`, i.e. it is
# optional and may be passed as None.
quantized_decomposed_lib.define(
"dequantize_per_channel_group(Tensor input, Tensor scales, Tensor? zero_points, int quant_min, "
"int quant_max, ScalarType dtype, int group_size, ScalarType output_dtype) -> Tensor"
)
@impl(quantized_decomposed_lib, "dequantize_per_channel_group", "CompositeExplicitAutograd")
def dequantize_per_channel_group(
    w_int8: torch.Tensor,
    scales: torch.Tensor,
    zero_points: Optional[torch.Tensor],
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    group_size: int = 128,
    output_dtype: torch.dtype = torch.float32,
):
    """Groupwise affine dequantization of a 2-d Tensor.

    The quantized input is viewed as rows of `group_size` consecutive
    elements; each row is mapped back to floating point as
    `(q - zero_point) * scale` with its own scale and zero point. The result
    is reshaped to the shape of `w_int8` and cast to `output_dtype`.

    `quant_min`, `quant_max` and `dtype` are part of the op signature but are
    not used by this kernel.

    Args:
       w_int8 (torch.Tensor): quantized 2-d Tensor (uint8/int8 etc.)
       scales (torch.Tensor): per-group scale; flattened to a column internally
       zero_points (Optional[torch.Tensor]): per-group zero point, or None to
           use a zero point of 0 for every group
       quant_min (int): minimum quantized value for input Tensor
       quant_max (int): maximum quantized value for input Tensor
       dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
       group_size (int): elements per group; must be > 1 and divide the last
           dimension of `w_int8`
       output_dtype (torch.dtype): dtype (e.g. torch.float32) for output Tensor

    Returns:
       dequantized Tensor with dtype `output_dtype`

    Raises:
       AssertionError: on an invalid `group_size` or a non 2-d input
    """
    assert group_size > 1
    row_width = w_int8.shape[-1]
    # needed for GPTQ single column dequantize
    if group_size > row_width and scales.shape[-1] == 1:
        group_size = row_width

    assert row_width % group_size == 0
    assert w_int8.dim() == 2

    grouped = w_int8.reshape(-1, group_size)
    scales = scales.reshape(-1, 1)

    if zero_points is None:
        # No zero point supplied: subtract a scalar int32 zero instead.
        zp = torch.zeros([], dtype=torch.int32, device=scales.device)
    else:
        zp = zero_points.reshape(-1, 1)

    dequantized = grouped.sub(zp).mul(scales)
    return dequantized.reshape_as(w_int8).to(output_dtype)
# Declare the schema of the `fake_quant_per_channel` op on quantized_decomposed_lib:
# (input, scales, zero_points, axis, quant_min, quant_max) -> Tensor.
# The schema string is split over two adjacent literals only for line length.
quantized_decomposed_lib.define(
"fake_quant_per_channel(Tensor input, Tensor scales, Tensor zero_points, int axis, "
"int quant_min, int quant_max) -> Tensor"
)
class FakeQuantPerChannel(torch.autograd.Function):
    """Per-channel fake quantization (quantize then dequantize) with a custom gradient.

    Forward rounds `input / scale + zero_point`, clamps it to
    [quant_min, quant_max] and maps it back to floating point, using one
    scale / zero point per slice along `axis`. Backward passes the incoming
    gradient through wherever the rounded value was inside the clamp range
    and zeroes it elsewhere.
    """

    @staticmethod
    def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
        """Fake-quantize `input` along `axis`.

        Args:
           ctx: autograd context; the in-range mask is saved on it for backward
           input (torch.Tensor): float32 Tensor to fake-quantize
           scales (torch.Tensor): per-channel scale, converted to float32 if needed
           zero_points (torch.Tensor): per-channel zero point, converted to int32 if needed
           axis (int): channel dimension of `input`; negative values count
               from the last dimension
           quant_min (int): lower clamp bound in the quantized domain
           quant_max (int): upper clamp bound in the quantized domain

        Returns:
           float Tensor holding the fake-quantized values

        Raises:
           AssertionError: if `input` is not float32 or `axis` is out of range
        """
        if scales.dtype != torch.float32:
            scales = scales.to(torch.float32)
        if zero_points.dtype != torch.int32:
            zero_points = zero_points.to(torch.int32)
        assert (
            input.dtype == torch.float32
        ), f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
        assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
        if axis < 0:
            # A negative axis used to slip past the check above and then yield
            # wrong broadcast dims (and a result with extra dimensions).
            # Normalize it so the ranges below enumerate the other dimensions.
            assert axis >= -input.dim(), f"Expecting axis to be >= {-input.dim()}"
            axis += input.dim()
        # Every dimension except the channel axis gets a size-1 slot so the
        # per-channel parameters broadcast against the input.
        broadcast_dims = list(range(0, axis)) + list(range(axis + 1, input.ndim))
        unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
        unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)
        temp = torch.round(input * (1.0 / unsqueeze_scales)) + unsqueeze_zero_points
        out = (
            torch.clamp(temp, quant_min, quant_max) - unsqueeze_zero_points
        ) * unsqueeze_scales
        # True where the value was not clipped; these positions keep their gradient.
        mask = torch.logical_and((temp >= quant_min), (temp <= quant_max))
        ctx.save_for_backward(mask)
        return out

    @staticmethod
    def backward(ctx, gy):
        """Return `gy` masked by the in-range mask; all other inputs get no gradient."""
        (mask,) = ctx.saved_tensors
        # One entry per forward argument after ctx: only `input` is differentiable.
        return gy * mask, None, None, None, None, None
@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Autograd")
def fake_quant_per_channel(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
) -> torch.Tensor:
    """Autograd kernel for `fake_quant_per_channel`.

    Forwards all arguments, in order, to `FakeQuantPerChannel.apply` and
    returns its result, so the custom backward of that autograd Function is
    used for this op.
    """
    forward_args = (input, scales, zero_points, axis, quant_min, quant_max)
    return FakeQuantPerChannel.apply(*forward_args)
@impl(quantized_decomposed_lib, "fake_quant_per_channel", "Meta")
def fake_quant_per_channel_meta(
    input: torch.Tensor,
    scales: torch.Tensor,
    zero_points: torch.Tensor,
    axis: int,
    quant_min: int,
    quant_max: int,
) -> torch.Tensor:
    """Meta kernel for `fake_quant_per_channel`.

    Returns an uninitialized Tensor with the same shape and dtype as `input`;
    none of the quantization arguments are inspected.
    """
    placeholder = torch.empty_like(input)
    return placeholder
# Declare the schema of `convert_element_type` with overload name `no_fuse` on
# quantized_decomposed_lib: (Tensor input, ScalarType dtype) -> Tensor.
# NOTE(review): the overload name suggests this cast is meant to be kept out of
# fusion by a compiler backend -- confirm against the consumers of this op.
quantized_decomposed_lib.define(
"convert_element_type.no_fuse(Tensor input, ScalarType dtype) -> Tensor"
)
@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "CompositeExplicitAutograd")
def convert_element_type(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """CompositeExplicitAutograd kernel for `convert_element_type.no_fuse`.

    Casts `input` to `dtype` by calling the
    `prims.convert_element_type.default` op and returns its result.
    """
    prim_convert = torch.ops.prims.convert_element_type.default
    return prim_convert(input, dtype)
@impl(quantized_decomposed_lib, "convert_element_type.no_fuse", "Meta")
def convert_element_type_meta(input: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Meta kernel for `convert_element_type.no_fuse`.

    Returns an uninitialized Tensor shaped like `input` with dtype `dtype`.
    """
    placeholder = torch.empty_like(input, dtype=dtype)
    return placeholder
|