# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import copy
import functools
import itertools
import math
import operator
from typing import Any, Tuple
import torch
from torch._dynamo.utils import counters
from torch.fx.experimental.symbolic_shapes import has_free_symbols
from torch.fx.node import map_arg
from ..lowering import lowerings as L, require_channels_last
from ..pattern_matcher import Arg, CallFunction, filter_nodes, KeywordArg, ListOf, Match
from ..utils import pad_listlike
from .freezing_patterns import register_freezing_graph_pattern
from .post_grad import register_lowering_pattern
aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed
quantized = torch.ops.quantized

# Only for per-tensor quant, since permute may change the channel index
_PER_TENSOR_QUANTIZE_OPS = [
    quantized_decomposed.quantize_per_tensor.default,
    quantized_decomposed.quantize_per_tensor.tensor,
]

_VIEW_OPS = [
    aten.transpose.int,
    aten.permute.default,
    aten.view.default,
]
"""
The quantization.py file primarily incorporates passes related to quantization fusion
in inductor, includes:
1. Dequant Promotion;
2. Conv/GEMM weight prepack with oneDNN Library;
3. Conv/GEMM quantization fusion with output quant node (if have);
4. Other pointwise operators' quantization fusion like: qmaxpool2d, qcat and more;
It also involves int8-mixed-fp32 and int8-mixed-bf16 quantization. The main difference
of patterns for int8-mixed-bf16, comparing with int8-mixed-fp32, is
1. There is to(dtype=torch.bfloat16) node at the inputs of activation and weight for Conv/GEMM.
2. There is to(dtype=torch.float32) node at the outputs of Conv/GEMM before inputs to next quant node.
Refer to: https://github.com/pytorch/pytorch/issues/111640 for detail design of int8-mixed-bf16
quantization.
"""


def _get_pattern_output_dtype(match: Match):
    """
    Get the pattern's output dtype from the output node's meta.
    Assume only 1 output node in this matched pattern.
    """
    pattern_output_nodes = match.output_nodes()
    assert len(pattern_output_nodes) == 1
    output_node = pattern_output_nodes[0]
    assert isinstance(output_node, torch.fx.Node)
    output_dtype = output_node.meta["val"].dtype
    assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
    return output_dtype


def _may_generate_pattern_with_dtype_convert(
    pattern, dtype=Arg(), with_dtype_convert=True, users=1
):
    if with_dtype_convert:
        return CallFunction(
            prims.convert_element_type.default,
            pattern,
            dtype,
            _users=users,
        )
    else:
        return pattern


def _may_generate_pattern_with_reshape(pattern, reshape_size=Arg(), with_reshape=True):
    if with_reshape:
        return CallFunction(
            torch.ops.aten.reshape.default,
            pattern,
            reshape_size,
        )
    else:
        return pattern


def _generate_linear_t_pattern(
    _dequant_per_channel_pattern,
    dtype,
):
    assert dtype in [torch.float32, torch.bfloat16]
    t_pattern = CallFunction(
        aten.permute.default,
        _may_generate_pattern_with_dtype_convert(
            _dequant_per_channel_pattern,
            KeywordArg("autocast_wgt_dtype"),
            dtype == torch.bfloat16,
        ),
        KeywordArg("permute_axes"),
    )
    return t_pattern


def _unary_fusion_pattern(unary_fusion, call_fn, users, is_bf16):
    # only insert to_dtype if is_bf16 is True
    computation_call = _may_generate_pattern_with_dtype_convert(
        call_fn, dtype=KeywordArg("to_float"), with_dtype_convert=is_bf16, users=users
    )
    return unary_fusion(computation_call)


def get_dequantize_per_tensor_activation_pattern(is_tensor_overload=False):
    dequantize_per_tensor_activation_pattern = CallFunction(
        quantized_decomposed.dequantize_per_tensor.tensor
        if is_tensor_overload
        else quantized_decomposed.dequantize_per_tensor.default,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        KeywordArg("x_zp"),
        KeywordArg("x_quant_min"),
        KeywordArg("x_quant_max"),
        KeywordArg("x_dq_dtype"),
    )
    return dequantize_per_tensor_activation_pattern


dequantize_per_channel_weight_pattern = CallFunction(
    quantized_decomposed.dequantize_per_channel.default,
    KeywordArg("q_weight"),
    KeywordArg("w_scale"),
    KeywordArg("w_zp"),
    KeywordArg("w_axis"),
    KeywordArg("w_quant_min"),
    KeywordArg("w_quant_max"),
    KeywordArg("w_dtype"),
)

dequantize_per_channel_to_bf16_weight_pattern = (
    _may_generate_pattern_with_dtype_convert(
        dequantize_per_channel_weight_pattern,
        KeywordArg("autocast_wgt_dtype"),
    )
)

dequantize_per_channel_clone_weight_pattern = CallFunction(
    aten.clone.default,
    dequantize_per_channel_weight_pattern,
    memory_format=KeywordArg("memory_format"),
)

dequantize_per_channel_to_bf16_clone_weight_pattern = CallFunction(
    aten.clone.default,
    dequantize_per_channel_to_bf16_weight_pattern,
    memory_format=KeywordArg("memory_format"),
)


def get_dequantize_qconv_pt2e_pattern(users=1):
    return CallFunction(
        torch.ops.onednn.qconv2d_pointwise.default,
        KeywordArg("x"),
        KeywordArg("x_scale"),  # x_scale
        KeywordArg("x_zp"),  # x_zp
        KeywordArg("packed_weight"),  # packed_weight
        KeywordArg("w_scale"),  # w_scale
        KeywordArg("w_zp"),  # w_zp
        KeywordArg("b"),  # bias
        KeywordArg("stride"),
        KeywordArg("padding"),
        KeywordArg("dilation"),
        KeywordArg("groups"),
        KeywordArg("output_scale"),  # output_scale = 1.0
        KeywordArg("output_zero_point"),  # output_zero_point = 0
        KeywordArg("output_dtype"),  # output_dtype = None
        KeywordArg("attr"),  # attr = "none"
        Arg(),  # scalars
        Arg(),  # algorithm
        _users=users,
    )


def get_qlinear_pt2e_pattern(x_scale_zp_are_tensors, users=1):
    qlinear_op = (
        torch.ops.onednn.qlinear_pointwise.tensor
        if x_scale_zp_are_tensors
        else torch.ops.onednn.qlinear_pointwise.default
    )
    return CallFunction(
        qlinear_op,
        KeywordArg("x"),
        KeywordArg("x_scale"),
        KeywordArg("x_zp"),
        KeywordArg("packed_weight"),
        KeywordArg("w_scale"),
        KeywordArg("w_zp"),
        KeywordArg("b"),
        KeywordArg("output_scale"),
        KeywordArg("output_zero_point"),
        KeywordArg("output_dtype"),
        KeywordArg("postop_name"),
        KeywordArg("postop_args"),
        KeywordArg("postop_algorithm"),
        _users=users,
    )


dequantize_accum_pattern = CallFunction(
    quantized_decomposed.dequantize_per_tensor.default,
    KeywordArg("accum"),
    KeywordArg("accum_scale"),
    KeywordArg("accum_zp"),
    Arg(),
    Arg(),
    KeywordArg("accum_dq_dtype"),
)


def generate_pattern_with_binary(
    binary_post_op,
    computation_call,
    extra_input_pattern,
    dtype_convert=False,
    swap_inputs=False,
):
    binary_pattern = (
        CallFunction(
            binary_post_op,
            extra_input_pattern,
            computation_call,
        )
        if swap_inputs
        else CallFunction(
            binary_post_op,
            computation_call,
            extra_input_pattern,
        )
    )
    return _may_generate_pattern_with_dtype_convert(
        binary_pattern,
        KeywordArg("convert_dtype_after_inplace_add"),
        dtype_convert,
    )


def generate_pattern_with_unary(computation_call, unary_post_op):
    if unary_post_op is not None:
        return CallFunction(
            unary_post_op,
            computation_call,
        )
    return computation_call


def generate_pattern_with_output_quant(computation_call, with_dtype_convert=False):
    quantized_op_output_pattern_pt2e = CallFunction(
        quantized_decomposed.quantize_per_tensor.default,
        _may_generate_pattern_with_dtype_convert(
            computation_call,
            Arg(),
            with_dtype_convert,
        ),
        KeywordArg("o_inv_scale"),
        KeywordArg("o_zp"),
        KeywordArg("o_qmin"),
        KeywordArg("o_qmax"),
        KeywordArg("o_dtype"),
    )
    return quantized_op_output_pattern_pt2e
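

# Illustrative sketch only; no pass in this file calls it. It mimics, with plain
# nested tuples instead of pattern_matcher.CallFunction, how the three helpers
# above wrap a computation pattern from the inside out. Op names are descriptive
# strings, and the dtype, scale, zero-point and quant-range arguments that the
# real patterns carry are omitted. All helper names here are hypothetical.
def _sketch_call(op, *args):
    return (op, *args)


def _sketch_binary(binary_op, compute, extra, dtype_convert=False, swap_inputs=False):
    # swap_inputs only changes operand order: add(extra, compute) vs add(compute, extra)
    args = (extra, compute) if swap_inputs else (compute, extra)
    pattern = _sketch_call(binary_op, *args)
    return _sketch_call("convert_element_type", pattern) if dtype_convert else pattern


def _sketch_unary(compute, unary_op):
    return compute if unary_op is None else _sketch_call(unary_op, compute)


def _sketch_output_quant(compute, with_dtype_convert=False):
    # The optional dtype convert sits between the computation and the quant node
    if with_dtype_convert:
        compute = _sketch_call("convert_element_type", compute)
    return _sketch_call("quantize_per_tensor", compute)


# Usage: _sketch_output_quant(_sketch_unary("qconv", "relu")) gives
# ("quantize_per_tensor", ("relu", "qconv")).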


def _check_node_kwarg_arg_value(check_node, kwarg_name, args_index, expected_value):
    if kwarg_name in check_node.kwargs:
        actual_value = check_node.kwargs[kwarg_name]
        return actual_value == expected_value
    else:
        assert len(check_node.args) >= (args_index + 1)
        actual_value = check_node.args[args_index]
        return actual_value == expected_value
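

# Illustrative sketch only; no pass in this file calls it. It shows where the
# positional indices 13 and 9 passed to _check_node_kwarg_arg_value below come
# from: they are the positions of "output_dtype" in the argument order used by
# get_dequantize_qconv_pt2e_pattern and get_qlinear_pt2e_pattern above. The
# names are the KeywordArg names from those patterns, assumed here to follow
# the operator schemas' positional order. Names prefixed _SKETCH/_sketch are
# hypothetical.
_SKETCH_QCONV_ARG_ORDER = (
    "x", "x_scale", "x_zp", "packed_weight", "w_scale", "w_zp", "b",
    "stride", "padding", "dilation", "groups",
    "output_scale", "output_zero_point", "output_dtype",
)
_SKETCH_QLINEAR_ARG_ORDER = (
    "x", "x_scale", "x_zp", "packed_weight", "w_scale", "w_zp", "b",
    "output_scale", "output_zero_point", "output_dtype",
)


def _sketch_read_arg(args, kwargs, name, arg_order):
    # Same rule as _check_node_kwarg_arg_value: prefer the keyword form,
    # otherwise read the positional slot derived from arg_order.
    if name in kwargs:
        return kwargs[name]
    index = arg_order.index(name)
    assert len(args) >= index + 1
    return args[index]


# Usage: _SKETCH_QCONV_ARG_ORDER.index("output_dtype") is 13 and
# _SKETCH_QLINEAR_ARG_ORDER.index("output_dtype") is 9.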


def _is_valid_quantized_conv2d_optimization_pattern():
    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        if output_dtype in [torch.float32, torch.bfloat16]:
            # Only keep matched pattern with same output_dtype
            qconv_node_after_weight_prepack = filter_nodes(
                match.nodes, torch.ops.onednn.qconv2d_pointwise
            )[0]
            return _check_node_kwarg_arg_value(
                qconv_node_after_weight_prepack, "output_dtype", 13, output_dtype
            )
        return True

    return fn


def _register_quantized_conv_lowering(
    pattern,
    pass_number,
    computation_op,
    unary_attr,
):
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_conv2d_optimization_pattern(),
        pass_number=pass_number,
    )
    def qconv(match: Match, *args, **kwargs):
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        # Conv Params
        b, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype in [torch.int8, torch.uint8, torch.float32, torch.bfloat16]
        # Output QParams
        o_inv_scale = (
            kwargs["o_inv_scale"]
            if (output_dtype == torch.uint8 or output_dtype == torch.int8)
            else 1.0
        )
        o_zero_point = (
            kwargs["o_zp"]
            if (output_dtype == torch.uint8 or output_dtype == torch.int8)
            else 0
        )
        assert (
            kwargs["attr"] == "none"
        )  # Expected no post op fused in weight prepack phase
        if unary_attr.op_name == "hardtanh":
            min_value = kwargs.get("min_value")
            max_value = kwargs.get("max_value")
            unary_attr.scalars_attr = [min_value, max_value]

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        )
        counters["inductor"]["qconv2d_unary_matcher_count"] += 1
        counters["inductor"]["qconv2d_unary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qconv
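

# Illustrative sketch only; the lowerings do not call it. It restates, without
# torch, the output quantization-parameter rule used by qconv above: int8/uint8
# outputs keep the matched o_inv_scale and o_zp, while fp32/bf16 outputs pass the
# identity values 1.0 and 0 because no output quant node was matched. (qlinear
# below applies the same rule but only checks for uint8.) Dtypes are plain
# strings here, and the helper name is hypothetical.
_SKETCH_QUANTIZED_OUTPUT_DTYPES = ("int8", "uint8")


def _sketch_output_qparams(output_dtype, matched_kwargs):
    if output_dtype in _SKETCH_QUANTIZED_OUTPUT_DTYPES:
        return matched_kwargs["o_inv_scale"], matched_kwargs["o_zp"]
    return 1.0, 0


# Usage: _sketch_output_qparams("uint8", {"o_inv_scale": 0.5, "o_zp": 3}) gives
# (0.5, 3); _sketch_output_qparams("bfloat16", {}) gives (1.0, 0).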


def _is_valid_quantized_linear_optimization_pattern():
    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        if output_dtype in [torch.float32, torch.bfloat16]:
            # Only keep matched pattern with same output_dtype
            qlinear_node_after_weight_prepack = filter_nodes(
                match.nodes, torch.ops.onednn.qlinear_pointwise
            )[0]
            return _check_node_kwarg_arg_value(
                qlinear_node_after_weight_prepack, "output_dtype", 9, output_dtype
            )
        return True

    return fn


def _register_quantized_linear_lowering(
    pattern,
    pass_number,
    computation_op,
    unary_attr,
):
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_linear_optimization_pattern(),
        pass_number=pass_number,
    )
    def qlinear(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )

        # bias
        b = kwargs["b"] if "b" in kwargs else None

        # Output QParams
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0
        assert (
            kwargs["postop_name"] == "none"
        )  # Expected no post op fused in weight prepack phase

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            b,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            unary_attr.op_name,
            unary_attr.scalars_attr,
            unary_attr.algorithm_attr,
        )
        counters["inductor"]["qlinear_unary_matcher_count"] += 1
        counters["inductor"]["qlinear_unary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qlinear


def _register_quantized_linear_binary_lowering(
    pattern,
    pass_number,
    computation_op,
    binary_unary_attr,
):
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_qlinear_binary_optimization_pattern(),
        pass_number=pass_number,
    )
    def qlinear_binary(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype is not None
        # Activation QParams
        x, x_scale, x_zp = (
            kwargs["x"],
            kwargs["x_scale"],
            kwargs["x_zp"],
        )
        x2 = (
            kwargs["accum"]
            if binary_unary_attr.binary_op_name == "sum"
            else kwargs["other"]
        )
        x2_scale = 1.0
        x2_zp = 0
        # Weight QParams
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        # bias
        b = kwargs["b"] if "b" in kwargs else None
        # Output QParams
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0

        x2.realize()
        from .mkldnn_fusion import _can_be_inplace

        binary_op_name = binary_unary_attr.binary_op_name

        if binary_op_name == "sum" and not _can_be_inplace(x2):
            # When we enable the GEMM Template, the output of QLinear
            # will be reshaped from 2D back to 3D if the input is 3D.
            # This causes _can_be_inplace(x2) to return False if x2 happens
            # to be the output of QLinear in this scenario.
            # Change the post op from sum to binary add for this case.
            # Refer to test case:
            #   test_mkldnn_pattern_matcher.py::test_qlinear_dequant_promotion_cpu_input_dim_exceeds_2
            binary_op_name = "add"

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            x2,
            b,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            x2_scale,
            x2_zp,
            binary_op_name,
            binary_unary_attr.alpha,
            binary_unary_attr.unary_op_name,
            binary_unary_attr.scalars_attr,
            binary_unary_attr.algorithm_attr,
        )
        counters["inductor"]["qlinear_binary_matcher_count"] += 1
        counters["inductor"]["qlinear_binary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qlinear_binary


def _is_valid_qconv_binary_optimization_pattern():
    return _is_valid_quantized_op_binary_optimization_pattern(
        torch.ops.onednn.qconv2d_pointwise
    )


def _is_valid_qlinear_binary_optimization_pattern():
    return _is_valid_quantized_op_binary_optimization_pattern(
        torch.ops.onednn.qlinear_pointwise,
        # we don't insert q-dq for extra input due to accuracy issues
        extra_input_from_dequant=False,
    )


def _is_valid_quantized_op_binary_optimization_pattern(
    qop, extra_input_from_dequant=True
):
    # Check if it's a valid Binary Pattern for qconv2d and qlinear:
    # * qop_pointwise should only have one user
    # * If extra_input_from_dequant is True, the extra input of the binary node should come from a dequant pattern
    # * the two inputs of the binary node should have attribute "meta" and should be tensors
    # * the two inputs of the binary node should have the same shape
    # * All users of the extra input in this pattern should be
    #   ancestor nodes of the compute node, except for the binary node
    #   connected to the compute node.
    def fn(match):
        output_dtype = _get_pattern_output_dtype(match)
        compute_node = filter_nodes(match.nodes, qop)[0]
        # qop_pointwise should only have one user
        if len(compute_node.users) != 1:
            return False
        binary_node_inputs = next(iter(compute_node.users)).args
        assert len(binary_node_inputs) == 2, "Expects binary node with 2 inputs"
        if output_dtype in [torch.float32, torch.bfloat16]:
            extra_input_of_binary_node = None
            for arg in binary_node_inputs:
                if arg != compute_node:
                    extra_input_of_binary_node = arg
                    break
            assert extra_input_of_binary_node is not None
            # Extra input of binary node comes from dequant pattern
            if extra_input_from_dequant and (
                (not isinstance(extra_input_of_binary_node, torch.fx.Node))
                or (
                    extra_input_of_binary_node.target
                    != quantized_decomposed.dequantize_per_tensor.default
                )
            ):
                return False

        # the two inputs of binary node should have attribute "meta" and should be tensors
        if not (
            hasattr(binary_node_inputs[0], "meta")
            and isinstance(binary_node_inputs[0].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
        ) or not (
            hasattr(binary_node_inputs[1], "meta")
            and isinstance(binary_node_inputs[1].meta.get("val", None), torch.Tensor)  # type: ignore[union-attr]
        ):
            return False
        # the two inputs of binary node should have the same shape
        if (
            binary_node_inputs[0].meta["val"].size()  # type: ignore[union-attr]
            != binary_node_inputs[1].meta["val"].size()  # type: ignore[union-attr]
        ):
            return False

        # All users of the extra input in this pattern should be
        # ancestor nodes of the compute node, except for the binary node
        # connected to the compute node.
        from .mkldnn_fusion import _get_remaining_users

        extra_input_of_pattern = (
            match.kwargs["other"]
            if "other" in match.kwargs
            else (
                match.kwargs["accum"]
                if output_dtype == torch.uint8 or (not extra_input_from_dequant)
                else match.kwargs["accum_after_dequant"]
            )
        )
        if (
            len(_get_remaining_users(extra_input_of_pattern, compute_node)) > 1
            or extra_input_of_pattern == compute_node.args[0]
        ):
            return False
        return True

    return fn
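

# Illustrative sketch only; no pass in this file calls it. It is a torch-free
# restatement of three of the checks in
# _is_valid_quantized_op_binary_optimization_pattern above: the compute node has
# exactly one user, the binary node's other operand is taken as the extra input,
# and both binary operands have the same shape. Nodes are hypothetical dicts:
# the compute node has "name" and "users"; `nodes` maps a name to a dict with
# "args" (for the binary node) or "shape". The helper name is hypothetical.
def _sketch_binary_checks(compute, nodes):
    if len(compute["users"]) != 1:
        return False
    binary = nodes[compute["users"][0]]
    extra = next((a for a in binary["args"] if a != compute["name"]), None)
    if extra is None:
        return False  # the real check asserts here instead of returning False
    lhs, rhs = binary["args"]
    return nodes[lhs]["shape"] == nodes[rhs]["shape"]


# Usage: with compute = {"name": "conv", "users": ["add"]} and an "add" node whose
# args are ("conv", "dq"), the check passes only if "conv" and "dq" shapes match.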


def _register_quantized_conv_binary_lowering(
    pattern,
    pass_number,
    computation_op,
    binary_unary_attr,
):
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_qconv_binary_optimization_pattern(),
        pass_number=pass_number,
    )
    def qconv_binary(match: Match, *args, **kwargs):
        output_dtype = _get_pattern_output_dtype(match)
        assert output_dtype is not None
        x, x_scale, x_zp = kwargs["x"], kwargs["x_scale"], kwargs["x_zp"]
        accum = (
            kwargs["accum"]
            if output_dtype == torch.uint8
            else kwargs["accum_after_dequant"]
        )
        accum_scale = kwargs["accum_scale"] if output_dtype == torch.uint8 else 1.0
        accum_zp = kwargs["accum_zp"] if output_dtype == torch.uint8 else 0
        packed_weight, w_scale, w_zp = (
            kwargs["packed_weight"],
            kwargs["w_scale"],
            kwargs["w_zp"],
        )
        b, stride, padding, dilation, groups = (
            kwargs["b"],
            kwargs["stride"],
            kwargs["padding"],
            kwargs["dilation"],
            kwargs["groups"],
        )
        # Output QParams
        o_inv_scale = kwargs["o_inv_scale"] if output_dtype == torch.uint8 else 1.0
        o_zero_point = kwargs["o_zp"] if output_dtype == torch.uint8 else 0

        accum.realize()
        from .mkldnn_fusion import _can_be_inplace

        assert _can_be_inplace(
            accum
        ), "QConv Binary Inplace Fusion requires that accum is not an alias or mutation."

        computation_args = (
            x,
            x_scale,
            x_zp,
            packed_weight,
            w_scale,
            w_zp,
            accum,
            b,
            stride,
            padding,
            dilation,
            groups,
            o_inv_scale,
            o_zero_point,
            output_dtype,
            accum_scale,
            accum_zp,
            binary_unary_attr.binary_op_name,
            binary_unary_attr.alpha,
            binary_unary_attr.unary_op_name,
            binary_unary_attr.scalars_attr,
            binary_unary_attr.algorithm_attr,
        )
        counters["inductor"]["qconv2d_binary_matcher_count"] += 1
        counters["inductor"]["qconv2d_binary_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qconv_binary
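

# Illustrative sketch only; no pass in this file calls it. It shows why the
# comment in _register_quantization_unary_fusion below asks for the larger
# pattern to be tried first: if qconv -> relu (pattern1) were accepted before
# qconv -> relu -> quant (pattern2), the trailing quant node would be left
# unfused. Priority is modeled as plain list order; how the real matcher orders
# patterns (for example via pass_number) is not modeled. The graph chain and
# patterns are hypothetical lists of op names.
def _sketch_first_match(graph_chain, patterns_by_priority):
    # Return the first pattern, in priority order, that is a prefix of the chain.
    for pattern in patterns_by_priority:
        if graph_chain[: len(pattern)] == pattern:
            return pattern
    return None


# Usage: with chain ["qconv", "relu", "quant"], listing pattern2 before pattern1
# selects the fused-quant pattern; the reverse order selects ["qconv", "relu"].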


def _register_quantization_unary_fusion():
    from .mkldnn_fusion import (
        _gelu_fusion_1 as _gelu_fusion_erf,
        _gelu_fusion_2 as _gelu_fusion_tanh,
        _hardswish_fusion,
        _hardtanh_fusion,
        _silu_fusion,
    )

    class UnaryAttr:
        def __init__(
            self, op_name: str, scalars_attr=None, algorithm_attr=None
        ) -> None:
            self.op_name = op_name
            self.scalars_attr = scalars_attr if scalars_attr else []
            self.algorithm_attr = algorithm_attr if algorithm_attr else ""

    for original_pattern_output_dtype in [torch.float32, torch.bfloat16]:
        # QConv2d
        # Priority 1 to match: QConv2d Unary pattern with int8 output
        # If pattern1 is a subset of pattern2, we should try to match pattern2 first.
        # For example: pattern1 is qconv_fp32 -> relu, pattern2 is qconv_fp32 -> relu -> quant
        is_bf16 = original_pattern_output_dtype == torch.bfloat16
        conv_unary_replace_patterns = {
            UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
                get_dequantize_qconv_pt2e_pattern(1),
            ),
            UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
                generate_pattern_with_unary(
                    get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
                ),
            ),
            UnaryAttr("hardtanh", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _hardtanh_fusion,
                    get_dequantize_qconv_pt2e_pattern(1),
                    1,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
            UnaryAttr("hardswish", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _hardswish_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
            UnaryAttr("swish", [], ""): generate_pattern_with_output_quant(
                _unary_fusion_pattern(
                    _silu_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                with_dtype_convert=is_bf16,
            ),
        }

        for unary_attr, patterns in conv_unary_replace_patterns.items():
            # Register qconv2d pattern for ExternKernel Lowering
            _register_quantized_conv_lowering(
                patterns,
                1,  # pass_number
                torch.ops.onednn.qconv2d_pointwise,  # computation_op
                unary_attr,  # unary_attr
            )

        # Priority 2 to match: QConv2d Unary pattern with fp32/bfloat16 output
        conv_unary_replace_float_out_patterns = {
            UnaryAttr("relu", [], ""): generate_pattern_with_unary(
                get_dequantize_qconv_pt2e_pattern(1), aten.relu.default
            ),
            UnaryAttr("hardtanh", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _hardtanh_fusion,
                    get_dequantize_qconv_pt2e_pattern(1),
                    1,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
            UnaryAttr("hardswish", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _hardswish_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
            UnaryAttr("swish", [], ""): _may_generate_pattern_with_dtype_convert(
                _unary_fusion_pattern(
                    _silu_fusion,
                    get_dequantize_qconv_pt2e_pattern(1 if is_bf16 else 2),
                    2,
                    is_bf16,
                ),
                Arg(),
                is_bf16,
            ),
        }

        for unary_attr, patterns in conv_unary_replace_float_out_patterns.items():
            # Register qconv2d pattern for ExternKernel Lowering
            _register_quantized_conv_lowering(
                patterns,
                2,  # pass_number
                torch.ops.onednn.qconv2d_pointwise,  # computation_op
                unary_attr,  # unary_attr
            )

        # QLinear
        for x_scale_zp_are_tensors in (False, True):
            qlinear_pattern = get_qlinear_pt2e_pattern(x_scale_zp_are_tensors)
            # Priority 1 to match: QLinear Unary pattern with int8 output
            linear_unary_replace_patterns = {
                UnaryAttr("none", [], ""): generate_pattern_with_output_quant(
                    qlinear_pattern,
                ),
                UnaryAttr("relu", [], ""): generate_pattern_with_output_quant(
                    generate_pattern_with_unary(qlinear_pattern, aten.relu.default),
                ),
                UnaryAttr("gelu", [], "none"): generate_pattern_with_output_quant(
                    _unary_fusion_pattern(
                        _gelu_fusion_erf,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
                        ),
                        2,
                        is_bf16,
                    ),
                    with_dtype_convert=is_bf16,
                ),
                UnaryAttr("gelu", [], "tanh"): generate_pattern_with_output_quant(
                    _unary_fusion_pattern(
                        _gelu_fusion_tanh,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
                        ),
                        4,
                        is_bf16,
                    ),
                    with_dtype_convert=is_bf16,
                ),
            }

            for unary_attr, patterns in linear_unary_replace_patterns.items():
                _register_quantized_linear_lowering(
                    patterns,
                    1,  # pass_number
                    torch.ops.onednn.qlinear_pointwise,  # computation_op
                    unary_attr,  # unary_attr
                )

            # Priority 2 to match: QLinear Unary pattern with FP32/BF16 output
            linear_unary_replace_float_out_patterns = {
                UnaryAttr("relu", [], ""): generate_pattern_with_unary(
                    qlinear_pattern, aten.relu.default
                ),
                UnaryAttr("gelu", [], "none"): _may_generate_pattern_with_dtype_convert(
                    _unary_fusion_pattern(
                        _gelu_fusion_erf,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 2
                        ),
                        2,
                        is_bf16,
                    ),
                    Arg(),
                    is_bf16,
                ),
                UnaryAttr("gelu", [], "tanh"): _may_generate_pattern_with_dtype_convert(
                    _unary_fusion_pattern(
                        _gelu_fusion_tanh,
                        get_qlinear_pt2e_pattern(
                            x_scale_zp_are_tensors, 1 if is_bf16 else 4
                        ),
                        4,
                        is_bf16,
                    ),
                    Arg(),
                    is_bf16,
                ),
            }

            for unary_attr, patterns in linear_unary_replace_float_out_patterns.items():
                _register_quantized_linear_lowering(
                    patterns,
                    2,  # pass_number
                    torch.ops.onednn.qlinear_pointwise,  # computation_op
                    unary_attr,  # unary_attr
                )
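

# Illustrative sketch (hypothetical helper, not used by the pattern matcher): it
# models why the int8-output patterns above are registered in an earlier pass
# than their fp32/bf16-output sub-patterns. Under the simplifying assumption that
# passes run in ascending pass_number order and each pass greedily claims a
# matching prefix of a linear chain of op names, a longer pattern registered in an
# earlier pass fuses the trailing quant op before a shorter sub-pattern can claim
# the chain and leave the quant op unfused. The real matcher works on FX graphs.
def _example_match_by_priority(chain, registered):
    # `registered` maps pass_number -> list of (pattern_ops, fused_name).
    for pass_number in sorted(registered):
        for pattern_ops, fused_name in registered[pass_number]:
            if tuple(chain[: len(pattern_ops)]) == tuple(pattern_ops):
                # Return the fused op name plus the ops left unfused.
                return fused_name, list(chain[len(pattern_ops) :])
    return None, list(chain)


# Example: with pass 1 holding ("qconv", "relu", "quant") and pass 2 holding
# ("qconv", "relu"), the chain ["qconv", "relu", "quant"] is fused whole;
# swapping the two pass numbers would leave "quant" behind.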


def _register_quantization_binary_fusion():
    class BinaryUnaryAttr:
        def __init__(
            self,
            binary_op_name: str,
            alpha=None,
            unary_op_name: str = "none",
            scalars_attr=None,
            algorithm_attr=None,
        ) -> None:
            self.binary_op_name = binary_op_name
            self.alpha = alpha if alpha else 1.0
            self.unary_op_name = unary_op_name
            self.scalars_attr = scalars_attr if scalars_attr else []
            self.algorithm_attr = algorithm_attr if algorithm_attr else ""

    for int8_mixed_bf16_with_inplace_add in [False, True]:
        # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output
        swap_binary_inputs_list = [False, True]
        binary_replace_patterns = {}
        for swap_inputs in swap_binary_inputs_list:
            binary_replace_patterns.update(
                {
                    BinaryUnaryAttr(
                        "sum", 1.0, "none", [], ""
                    ): generate_pattern_with_output_quant(
                        generate_pattern_with_binary(
                            aten.add.Tensor,
                            get_dequantize_qconv_pt2e_pattern(1),
                            dequantize_accum_pattern,
                            int8_mixed_bf16_with_inplace_add,
                            swap_inputs=swap_inputs,
                        ),
                    ),
                    BinaryUnaryAttr(
                        "sum", 1.0, "relu", [], ""
                    ): generate_pattern_with_output_quant(
                        generate_pattern_with_unary(
                            generate_pattern_with_binary(
                                aten.add.Tensor,
                                get_dequantize_qconv_pt2e_pattern(1),
                                dequantize_accum_pattern,
                                int8_mixed_bf16_with_inplace_add,
                                swap_inputs=swap_inputs,
                            ),
                            aten.relu.default,
                        ),
                    ),
                }
            )

        for binary_unary_attr, patterns in binary_replace_patterns.items():
            _register_quantized_conv_binary_lowering(
                patterns,
                0,  # pass_number
                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                binary_unary_attr,  # binary_unary_attr
            )

        # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output
        binary_replace_float_out_patterns = {}
        for swap_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "sum", 1.0, "relu", [], ""
                    ): generate_pattern_with_unary(
                        generate_pattern_with_binary(
                            aten.add.Tensor,
                            get_dequantize_qconv_pt2e_pattern(1),
                            KeywordArg("accum_after_dequant"),
                            int8_mixed_bf16_with_inplace_add,
                            swap_inputs=swap_inputs,
                        ),
                        aten.relu.default,
                    )
                }
            )

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            if int8_mixed_bf16_with_inplace_add:
                _register_quantized_conv_binary_lowering(
                    patterns,
                    0,  # pass_number
                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                    binary_unary_attr,  # binary_unary_attr
                )
            else:
                _register_quantized_conv_binary_lowering(
                    patterns,
                    1,  # pass_number
                    torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                    binary_unary_attr,  # binary_unary_attr
                )

        # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output
        binary_replace_float_out_patterns = {}
        for swap_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "sum", 1.0, "none", [], ""
                    ): generate_pattern_with_binary(
                        aten.add.Tensor,
                        get_dequantize_qconv_pt2e_pattern(1),
                        KeywordArg("accum_after_dequant"),
                        int8_mixed_bf16_with_inplace_add,
                        swap_inputs=swap_inputs,
                    ),
                }
            )

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_conv_binary_lowering(
                patterns,
                1 if int8_mixed_bf16_with_inplace_add else 2,  # pass_number
                torch.ops.onednn.qconv2d_pointwise.binary,  # computation_op
                binary_unary_attr,  # binary_unary_attr
            )

    # QLinear
    r"""
    Supported linear-binary(-unary) patterns

        linear(X)   extra input
               \   /
                Add
                 |
            Optional(relu)
                 |
                 Y

    1. int8-mixed-fp32
    +---+---------------+-----------+------------------------------+---------+
    | # | Add type      | Quant out | Pattern                      | Post op |
    +---+---------------+-----------+------------------------------+---------+
    | 1 | In-/out-place | Yes       | linear + fp32 -> (relu) -> q | add     |
    +---+---------------+-----------+------------------------------+---------+
    | 2 | In-/out-place | No        | linear + fp32 -> (relu)      | sum     |
    +---+---------------+-----------+------------------------------+---------+

    2. int8-mixed-bf16
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | # | X2 dtype | Add type      | Quant out | Pattern                                 | Post op |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 1 | BF16     | In-/out-place | Yes       | linear + bf16 -> (relu) -> q            | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 2 | BF16     | In-/out-place | No        | linear + bf16 -> (relu)                 | sum     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 3 | FP32     | Out-place     | Yes       | linear + fp32 -> (relu) -> q            | add     |
    |   |          | In-place right|           |                                         |         |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 4 | FP32     | Out-place     | No        | linear + fp32 -> (relu)                 | sum     |
    |   |          | In-place right|           |                                         |         |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 5 | FP32     | In-place left | Yes       | linear + fp32 -> to_bf16 -> (relu) -> q | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+
    | 6 | FP32     | In-place left | No        | linear + fp32 -> to_bf16 -> (relu)      | add     |
    +---+----------+---------------+-----------+-----------------------------------------+---------+

    Note
    (1) The positions of linear and the extra input can be swapped.
    (2) By recipe, we don't insert q-dq before the extra input of linear-add. But if q-dq is found at the
    extra input, we don't match that pattern because we cannot match all these patterns in 3 passes.
    """
    for x_scale_zp_are_tensors in (False, True):
        qlinear_binary_op = (
            torch.ops.onednn.qlinear_pointwise.binary_tensor
            if x_scale_zp_are_tensors
            else torch.ops.onednn.qlinear_pointwise.binary
        )
        unary_postop_list = ["none", "relu"]
        unary_postop_dict = {
            "none": None,
            "relu": aten.relu.default,
        }
        convert_dtype_after_binary_list = [False, True]

        # Priority 1 to match: QLinear Binary or Binary-Unary pattern with int8 output
        # Covers case (1) of int8-mixed-fp32 and case (1)(3)(5) of int8-mixed-bf16,
        # 3 patterns in total, since 2 of these cases are identical
        swap_binary_inputs_list = [False, True]
        int8_mixed_bf16_list = [False, True]
        combinations = itertools.product(
            unary_postop_list,
            int8_mixed_bf16_list,
            swap_binary_inputs_list,
            convert_dtype_after_binary_list,
        )
        qlinear_binary_replace_patterns = {}
        for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary in combinations:
            if not int8_mixed_bf16 and cvt_dtype_binary:
                # No convert node after binary node if dtypes are all fp32
                continue
            qlinear_binary_replace_patterns.update(
                {
                    BinaryUnaryAttr(
                        "add", 1.0, unary_op, [], ""
                    ): generate_pattern_with_output_quant(
                        generate_pattern_with_unary(
                            generate_pattern_with_binary(
                                aten.add.Tensor,
                                get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                                KeywordArg("other"),
                                # If fp32 extra input is inplace added to bf16 linear output,
                                # a to_bf16 node is inserted after binary
                                dtype_convert=cvt_dtype_binary,
                                swap_inputs=swap_inputs,
                            ),
                            unary_postop_dict[unary_op],
                        ),
                    )
                }
            )
        for binary_unary_attr, patterns in qlinear_binary_replace_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                0,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,  # binary_unary_attr
            )

        # Priority 2.1 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
        # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
        # 2 patterns in total, since 2 of these cases are identical
        binary_replace_float_out_patterns = {}
        for swap_binary_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "sum", 1.0, "relu", [], ""
                    ): generate_pattern_with_unary(
                        generate_pattern_with_binary(
                            aten.add.Tensor,
                            get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                            KeywordArg("accum"),
                            dtype_convert=False,
                            swap_inputs=swap_binary_inputs,
                        ),
                        aten.relu.default,
                    ),
                }
            )

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                1,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,
            )

        # Priority 2.2 to match: QLinear Binary-Unary pattern with fp32/bfloat16 output
        # Covers case (6) of int8-mixed-bf16
        binary_replace_float_out_patterns = {}
        for swap_binary_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "add", 1.0, "relu", [], ""
                    ): generate_pattern_with_unary(
                        generate_pattern_with_binary(
                            aten.add.Tensor,
                            get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                            KeywordArg("other"),
                            dtype_convert=True,
                            swap_inputs=swap_binary_inputs,
                        ),
                        aten.relu.default,
                    ),
                }
            )

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                1,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,
            )

        # Priority 3.1: QLinear Binary pattern with fp32/bfloat16 output
        # Covers case (2) of int8-mixed-fp32 and case (2)(4) of int8-mixed-bf16,
        # 2 patterns in total, since 2 of these cases are identical
        binary_replace_float_out_patterns = {}
        for swap_binary_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "sum", 1.0, "none", [], ""
                    ): generate_pattern_with_binary(
                        aten.add.Tensor,
                        get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                        KeywordArg("accum"),
                        dtype_convert=False,
                        swap_inputs=swap_binary_inputs,
                    ),
                }
            )

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                2,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,
            )

        # Priority 3.2: QLinear Binary pattern with fp32/bfloat16 output
        # Covers case (6) of int8-mixed-bf16
        binary_replace_float_out_patterns = {}
        for swap_binary_inputs in swap_binary_inputs_list:
            binary_replace_float_out_patterns.update(
                {
                    BinaryUnaryAttr(
                        "add", 1.0, "none", [], ""
                    ): generate_pattern_with_binary(
                        aten.add.Tensor,
                        get_qlinear_pt2e_pattern(x_scale_zp_are_tensors),
                        KeywordArg("other"),
                        dtype_convert=True,
                        swap_inputs=swap_binary_inputs,
                    ),
                }
            )

        for (
            binary_unary_attr,
            patterns,
        ) in binary_replace_float_out_patterns.items():
            _register_quantized_linear_binary_lowering(
                patterns,
                2,  # pass_number
                qlinear_binary_op,  # computation_op
                binary_unary_attr,
            )
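

# Illustrative sketch (hypothetical helper, not used by the pattern matcher): it
# enumerates the (unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype_binary)
# tuples that the "Priority 1" QLinear binary loop above visits for each value of
# x_scale_zp_are_tensors, skipping the fp32-only tuples that would need a dtype
# convert after the add. Because int8_mixed_bf16 is not passed on to
# generate_pattern_with_binary, tuples that differ only in that flag describe the
# same pattern shape, so the 12 tuples cover only 8 distinct shapes.
def _example_qlinear_binary_int8_out_combinations():
    import itertools

    combos = []
    for unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype in itertools.product(
        ["none", "relu"], [False, True], [False, True], [False, True]
    ):
        if not int8_mixed_bf16 and cvt_dtype:
            # No convert node after the binary node when all dtypes are fp32
            continue
        combos.append((unary_op, int8_mixed_bf16, swap_inputs, cvt_dtype))
    return combos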


def _is_valid_quantized_maxpool2d_optimization_pattern():
    def fn(match):
        # Only match the pattern where getitem selects the values output
        # (index 0) of max_pool2d_with_indices, not the indices.
        get_item_node = filter_nodes(match.nodes, operator.getitem)[0]
        return get_item_node.args[1] == 0

    return fn


def _register_quantized_maxpool2d_lowering(
    pattern,
    computation_op,
):
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_quantized_maxpool2d_optimization_pattern(),
    )
    def qmaxpool2d(match: Match, *args, **kwargs):
        x = kwargs["x"]
        kernel_size = kwargs["kernel_size"]
        stride = kwargs["stride"] if ("stride" in kwargs) else None
        padding = kwargs["padding"] if ("padding" in kwargs) else 0
        dilation = kwargs["dilation"] if ("dilation" in kwargs) else 1
        ceil_mode = kwargs["ceil_mode"] if ("ceil_mode" in kwargs) else False

        if padding == 0:
            padding = [0, 0]
        if dilation == 1:
            dilation = [1, 1]
        if not stride:
            stride = kernel_size
        kernel_size = pad_listlike(kernel_size, 2)
        stride = pad_listlike(stride, 2)
        padding = pad_listlike(padding, 2)
        dilation = pad_listlike(dilation, 2)

        assert len(kernel_size) == 2
        assert len(stride) == 2
        assert len(padding) == 2
        assert len(dilation) == 2

        computation_args = (
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
        computation_args, _ = require_channels_last(computation_op, *computation_args)
        counters["inductor"]["qmaxpool2d_matcher_count"] += 1
        counters["inductor"]["qmaxpool2d_matcher_nodes"] += len(match.nodes)
        return L[computation_op](*computation_args)

    return qmaxpool2d
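

# Illustrative sketch (hypothetical helpers, not used by the lowering above): the
# argument defaulting that qmaxpool2d applies before calling the quantized
# max_pool2d kernel, in plain Python. A missing or empty stride falls back to
# kernel_size, and every argument is broadcast to two spatial entries. Unlike the
# code above, this sketch accepts plain ints directly, so it needs no separate
# padding == 0 / dilation == 1 special cases.
def _example_pad_to_2d(value):
    # Broadcast an int or a 1-element sequence to 2 entries; keep 2-element ones.
    values = [value] if isinstance(value, int) else list(value)
    return values * 2 if len(values) == 1 else values


def _example_normalize_maxpool2d_args(kernel_size, stride=None, padding=0, dilation=1):
    if not stride:
        stride = kernel_size
    normalized = tuple(
        _example_pad_to_2d(v) for v in (kernel_size, stride, padding, dilation)
    )
    assert all(len(v) == 2 for v in normalized)
    return normalized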


def _register_quantization_maxpool2d():
    # Currently, default parameters are not included in the FX graph generated by Dynamo export.
    # So, depending on which default parameters the user explicitly assigns when defining
    # nn.MaxPool2d, the generated graph has a different number of input nodes and hence
    # a different pattern to be matched.
    # Refer to the issue: https://github.com/pytorch/pytorch/issues/105901
    max_pool2d_args_list = [
        [
            KeywordArg("stride"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
            KeywordArg("dilation"),
        ],
        [
            KeywordArg("stride"),
            KeywordArg("padding"),
            KeywordArg("dilation"),
            KeywordArg("ceil_mode"),
        ],
    ]
    for max_pool2d_args in max_pool2d_args_list:
        dequantize_maxpool2d_pattern = CallFunction(
            aten.max_pool2d_with_indices.default,
            get_dequantize_per_tensor_activation_pattern(),
            KeywordArg("kernel_size"),
            *max_pool2d_args,
        )
        dequantize_lowmem_maxpool2d_pattern = CallFunction(
            prims._low_memory_max_pool2d_with_offsets.default,
            get_dequantize_per_tensor_activation_pattern(),
            KeywordArg("kernel_size"),
            *max_pool2d_args,
            KeywordArg("offset_dtype"),
        )
        dequantize_maxpool2d_get_item_pattern = CallFunction(
            operator.getitem,
            dequantize_maxpool2d_pattern,
            Arg(),
        )
        dequantize_lowmem_maxpool2d_get_item_pattern = CallFunction(
            operator.getitem,
            dequantize_lowmem_maxpool2d_pattern,
            Arg(),
        )
        _register_quantized_maxpool2d_lowering(
            generate_pattern_with_output_quant(dequantize_maxpool2d_get_item_pattern),
            quantized.max_pool2d.default,
        )
        _register_quantized_maxpool2d_lowering(
            generate_pattern_with_output_quant(
                dequantize_lowmem_maxpool2d_get_item_pattern
            ),
            quantized.max_pool2d.default,
        )


def _is_input_output_same_scale_zp(check_node):
    def fn(match):
        # Ensure all the inputs and the output have the same scale and zero point
        # Step 1: Check inputs/output zero point
        # Get dequant nodes at input
        dequant_nodes = filter_nodes(
            match.nodes, quantized_decomposed.dequantize_per_tensor.default
        )
        zero_points = [node.args[2] for node in dequant_nodes]
        # Get quant nodes at output
        quant_nodes = filter_nodes(
            match.nodes, quantized_decomposed.quantize_per_tensor.default
        )
        assert len(quant_nodes) == 1, "expect only 1 quant node at output quant pattern"
        zero_points.append(quant_nodes[0].args[2])
        if not all(zero_point == zero_points[0] for zero_point in zero_points):
            return False

        # Step 2: Check inputs/output scale
        scales = [node.args[1] for node in dequant_nodes]
        scales.append(quant_nodes[0].args[1])
        if not all(math.isclose(scale, scales[0], rel_tol=1e-5) for scale in scales):  # type: ignore[arg-type]
            return False

        return True

    return fn
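

# Illustrative sketch (hypothetical helper, not used by the pattern matcher): the
# same two-step check as _is_input_output_same_scale_zp, applied to plain
# (scale, zero_point) tuples instead of FX nodes. Zero points must match exactly,
# while scales only need to agree within a relative tolerance of 1e-5.
def _example_same_qparams(input_qparams, output_qparams, rel_tol=1e-5):
    import math

    all_qparams = list(input_qparams) + [output_qparams]
    # Step 1: every zero point must equal the first one
    zero_points = [zp for _, zp in all_qparams]
    if any(zp != zero_points[0] for zp in zero_points):
        return False
    # Step 2: every scale must be close to the first one
    scales = [scale for scale, _ in all_qparams]
    return all(math.isclose(scale, scales[0], rel_tol=rel_tol) for scale in scales)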


def _register_quantized_cat_lowering(
    pattern,
    computation_op,
):
    @register_lowering_pattern(
        pattern,
        extra_check=_is_input_output_same_scale_zp(aten.cat.default),
    )
    def qcat(match: Match, inputs, dim, **kwargs):
        # inputs has the format: [[x1, x1_dq_dtype, x1_zp, x1_scale], ...]
        uint8_inputs = [input[0] for input in inputs]
        counters["inductor"]["qcat_matcher_count"] += 1
        counters["inductor"]["qcat_matcher_nodes"] += len(match.nodes)
        return L[computation_op](uint8_inputs, dim)

    return qcat


_raw_dequantize_per_tensor_activation_pattern = CallFunction(
    quantized_decomposed.dequantize_per_tensor.default,
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    Arg(),
    Arg(),
)


def _register_quantization_cat():
    dequantize_cat_pattern = CallFunction(
        aten.cat.default,
        ListOf(_raw_dequantize_per_tensor_activation_pattern),
        KeywordArg("dim"),
    )
    _register_quantized_cat_lowering(
        generate_pattern_with_output_quant(dequantize_cat_pattern),
        aten.cat,
    )


def _register_quantized_reshape_lowering(
    pattern,
    computation_op,
):
    @register_lowering_pattern(
        pattern,
        extra_check=_is_input_output_same_scale_zp(aten.reshape.default),
    )
    def qreshape(match: Match, *args, **kwargs):
        qx = kwargs["x"]
        shape = kwargs["shape"]
        counters["inductor"]["qreshape_matcher_count"] += 1
        counters["inductor"]["qreshape_matcher_nodes"] += len(match.nodes)
        return L[computation_op](qx, shape)

    return qreshape


def _register_quantization_reshape():
    dequantize_reshape_pattern = CallFunction(
        torch.ops.aten.reshape.default,
        get_dequantize_per_tensor_activation_pattern(),
        KeywordArg("shape"),
    )
    _register_quantized_reshape_lowering(
        generate_pattern_with_output_quant(dequantize_reshape_pattern),
        aten.reshape,
    )


def _is_valid_woq_optimization_pattern():
    def fn(match):
        assert all(k in match.kwargs for k in ("x", "weight", "scales"))
        x = match.kwargs["x"].meta["val"]
        weight = match.kwargs["weight"].meta["val"]
        scales = match.kwargs["scales"].meta["val"]
        return (
            # For now, we only support woq mm kernels
            # with x.type=bfloat16 and w.type=int8
            x.dtype == torch.bfloat16
            and weight.dtype == torch.int8
            and scales.dtype == torch.bfloat16
            # _weight_int8pack_mm kernel only supports cpu now
            # TODO: add cuda kernel support instead of calling mul+sum
            and x.device.type == "cpu"
            and x.device == weight.device
            and x.device == scales.device
        )

    return fn


def _register_woq_lowering(pattern, computation_woq, computation_reshape):
    @register_lowering_pattern(
        pattern,
        extra_check=_is_valid_woq_optimization_pattern(),
    )
    def woq(match: Match, *args, **kwargs):
        x = kwargs["x"]
        weight = kwargs["weight"]
        scales = kwargs["scales"]
        counters["inductor"]["woq_matcher_count"] += 1
        counters["inductor"]["woq_matcher_nodes"] += len(match.nodes)
        out_features = weight.get_size()[0]
        origin_x_size = x.get_size()
        x_shape = [-1, origin_x_size[-1]]
        out_shape = origin_x_size[:-1] + [
            out_features,
        ]
        func1 = L[computation_reshape](x, x_shape)
        func2 = L[computation_woq](func1, weight, scales)
        return L[computation_reshape](func2, out_shape)

    return woq
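

# Illustrative sketch (hypothetical helpers, plain Python lists instead of
# tensors): the computation the WOQ patterns below match,
# F.linear(x, weight.to(dtype=x.dtype)) * scales with one scale per output
# channel, plus the shape bookkeeping the woq lowering above performs around the
# kernel call. The dtype and device requirements checked in
# _is_valid_woq_optimization_pattern are ignored here.
def _example_woq_linear(x, weight, scales):
    # x: M rows of K values; weight: N rows of K int8 values; scales: N values.
    return [
        [
            sum(x_k * w_k for x_k, w_k in zip(x_row, w_row)) * scale
            for w_row, scale in zip(weight, scales)
        ]
        for x_row in x
    ]


def _example_woq_shapes(x_size, out_features):
    # Flatten x to 2-D for the kernel, then restore the leading dimensions.
    x_shape = [-1, x_size[-1]]
    out_shape = list(x_size[:-1]) + [out_features]
    return x_shape, out_shape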


def _register_woq_mm_int8_pattern1():
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to mm, with x reshape
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.reshape.default,
            CallFunction(
                aten.mm.default,
                CallFunction(aten.reshape.default, KeywordArg("x"), Arg()),
                CallFunction(
                    aten.permute.default,
                    CallFunction(
                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
                    ),
                    Arg(),
                ),
            ),
            Arg(),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)


def _register_woq_mm_int8_pattern2():
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to mm, w/o x reshape
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.reshape.default,
            CallFunction(
                aten.mm.default,
                KeywordArg("x"),
                CallFunction(
                    aten.permute.default,
                    CallFunction(
                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
                    ),
                    Arg(),
                ),
            ),
            Arg(),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)


def _register_woq_mm_int8_pattern3():
    # F.linear(x, weight.to(dtype=x.dtype)) * scales
    # case of dispatching to bmm
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.bmm.default,
            CallFunction(aten.expand.default, KeywordArg("x"), Arg()),
            CallFunction(
                aten.expand.default,
                CallFunction(
                    aten.permute.default,
                    CallFunction(
                        prims.convert_element_type.default, KeywordArg("weight"), Arg()
                    ),
                    Arg(),
                ),
                Arg(),
            ),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)


def _register_woq_mm_int8_pattern4():
    _woq_pattern = CallFunction(
        aten.mul.Tensor,
        CallFunction(
            aten.mm.default,
            KeywordArg("x"),
            CallFunction(
                prims.convert_element_type.default,
                CallFunction(
                    aten.permute.default,
                    KeywordArg("weight"),
                    Arg(),
                ),
                Arg(),
            ),
        ),
        KeywordArg("scales"),
    )
    _register_woq_lowering(_woq_pattern, aten._weight_int8pack_mm.default, aten.reshape)


def _register_quantization_lowerings():
    _register_quantization_unary_fusion()
    _register_quantization_binary_fusion()
    _register_quantization_maxpool2d()
    _register_quantization_cat()
    _register_quantization_reshape()


def _register_woq_lowerings():
    _register_woq_mm_int8_pattern1()
    _register_woq_mm_int8_pattern2()
    _register_woq_mm_int8_pattern3()
    _register_woq_mm_int8_pattern4()


def _is_valid_dequant_promotion_pattern(dtype=torch.float32):
    def _inner(match):
        assert dtype in [torch.float32, torch.bfloat16]
        dequant_pattern_end_node = match.output_node()
        if dequant_pattern_end_node.target not in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]:
            return False

        if dequant_pattern_end_node.target is aten.reshape.default:
            dequant_node = (
                # pattern: linear <- reshape <- dequant
                dequant_pattern_end_node.args[0]
                if dtype == torch.float32
                # pattern: linear <- reshape <- to_bf16 <- dequant
                else dequant_pattern_end_node.args[0].args[0]
            )
        else:
            dequant_node = (
                # pattern: linear <- dequant
                dequant_pattern_end_node
                if dtype == torch.float32
                # pattern: linear <- to_bf16 <- dequant
                else dequant_pattern_end_node.args[0]
            )

        if (
            dequant_node.target
            in [
                quantized_decomposed.dequantize_per_tensor.default,
                quantized_decomposed.dequantize_per_tensor.tensor,
            ]
            and len(list(dequant_pattern_end_node.users)) > 1
        ):
            # If the dequant pattern has more than 1 user, promote the dequant
            # so that each user gets its own copy
            return True
        return False

    return _inner


def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
    @register_freezing_graph_pattern(
        pattern,
        extra_check=_is_valid_dequant_promotion_pattern(dtype),
        pass_number=pass_number,
    )
    def dequant_promotion(match: Match, *args, **kwargs):
        # Dequant_promotion will transform
        # graph 1:
        #            quant
        #      + - - - | - - - +
        #      |    dequant    |
        #      |    /     \    |
        #      |  node1  node2 |
        #      + - | - - - | - +
        #        quant   quant
        # into:
        # graph 2:
        #            quant
        #      + - - / - \ - - +
        #      |dequant dequant|
        #      |    |      |   |
        #      | node1 node2   |
        #      + - | - - - | - +
        #        quant   quant
        # In graph 1, the dequant node is shared by node1 and node2,
        # so neither node1 nor node2 can form an int8 fusion pattern.
        # After this transformation, graph 2 can hit the int8
        # fusion pattern dequant-node-quant separately for
        # node1 and node2.
        assert dtype in [torch.float32, torch.bfloat16]

        def clone_to_new_node(graph, source_node, user_node):
            # Clone the source_node to a new node
            # Replace user_node's input from source_node to new_node
            assert (
                source_node.op == "call_function"
            ), "clone_to_new_node only supports node.op call_function"
            with graph.inserting_before(user_node):
                new_node = graph.call_function(
                    source_node.target,
                    args=source_node.args,
                    kwargs=source_node.kwargs,
                )
                new_node.meta = copy.copy(source_node.meta)
                user_node.replace_input_with(source_node, new_node)
            return new_node

        # Find the start node and end node of a dequant pattern
        # * End node should be the match.output_node()
        # * Start node should be the node of dequantize_per_tensor
        dequant_pattern_end_node = match.output_node()
        assert dequant_pattern_end_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
            prims.convert_element_type.default,
            aten.reshape.default,
        ]

        # For a dequant pattern, we expect to see the node list as:
        #  * OPT(aten.reshape.default)
        #  * OPT(prims.convert_element_type.default) (to_bf16)
        #  * dequantize_per_tensor
        def _find_first_node_in_dequant_pattern(_node):
            if _node.target in [
                quantized_decomposed.dequantize_per_tensor.default,
                quantized_decomposed.dequantize_per_tensor.tensor,
            ]:
                # For a dequant pattern, we expect the start node to be a dequantize_per_tensor node
                return _node
            else:
                assert (
                    len(_node.args) >= 1
                ), "In the dequant pattern, each node should have at least 1 arg."
                return _find_first_node_in_dequant_pattern(_node.args[0])

        dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
            dequant_pattern_end_node
        )

        assert dequant_pattern_start_node.target in [
            quantized_decomposed.dequantize_per_tensor.default,
            quantized_decomposed.dequantize_per_tensor.tensor,
        ]

        # Clone the dequant pattern for each user node
        graph = match.graph
        user_node_list = list(dequant_pattern_end_node.users)
        for user_node in user_node_list[1:]:
            _source_node = dequant_pattern_end_node
            _user_node = user_node
            while _source_node != dequant_pattern_start_node.args[0]:
                _user_node = clone_to_new_node(graph, _source_node, _user_node)
                _source_node = _source_node.args[0]  # type: ignore[assignment]

        counters["inductor"]["dequant_promotion_matcher_count"] += 1
        counters["inductor"]["dequant_promotion_matcher_nodes"] += len(match.nodes)
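

# Illustrative sketch (hypothetical toy graph, not torch.fx): the cloning loop in
# dequant_promotion above. Every user of the shared end node except the first
# gets its own copy of the chain from the end node back to the dequant start
# node, and each copy is inserted right before the node that consumes it, which
# mirrors graph.inserting_before(user_node) in clone_to_new_node.
class _ExampleNode:
    def __init__(self, name, target, args=()):
        self.name = name
        self.target = target
        self.args = list(args)


def _example_promote_dequant(nodes, end_node, start_node):
    # `nodes` is a topologically ordered list that is updated in place.
    users = [node for node in nodes if end_node in node.args]
    stop_node = start_node.args[0]  # the quantized input feeding the dequant
    for user in users[1:]:
        source, consumer = end_node, user
        while source is not stop_node:
            clone = _ExampleNode(f"{source.name}_{user.name}", source.target, source.args)
            consumer.args = [clone if arg is source else arg for arg in consumer.args]
            nodes.insert(nodes.index(consumer), clone)
            source, consumer = source.args[0], clone
    return nodes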


def _is_valid_dequant_conv2d_pattern(dtype):
    def _inner(match):
        # Here we do some further checks to ensure:
        # 1. It's a conv2d node with dim of 4, since we only support lowering of conv2d now.
        # 2. The dequant pattern has only 1 user of conv2d node.
        # If these conditions aren't met, we will not
        # insert weight prepack node into the matched pattern.
        conv_node = match.output_node()
        assert conv_node.target is aten.convolution.default
        input_meta_value = conv_node.args[0].meta.get("val")
        weight_meta_value = conv_node.args[1].meta.get("val")
        for meta_value in [input_meta_value, weight_meta_value]:
            if (
                meta_value is None
                or (meta_value.device.type != "cpu" and meta_value.device.type != "xpu")
                or meta_value.dim() != 4
                or (meta_value.device.type == "xpu" and match.kwargs["groups"] != 1)
            ):
                # Only support conv2d now
                # Grouped quantized convolution is not supported on the XPU backend
                return False

        assert dtype in [torch.float32, torch.bfloat16]

        if dtype == torch.float32:
            dequant_node = conv_node.args[0]
        else:
            convert_to_bf16 = conv_node.args[0]
            dequant_node = convert_to_bf16.args[0]

        if len(list(dequant_node.users)) != 1:
            # Ensure the dequant pattern only has 1 user
            # since we will delete the dequant pattern here
            return False
        return True

    return _inner
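

# Illustrative sketch (hypothetical helper, plain values instead of FX nodes):
# the eligibility rules that _is_valid_dequant_conv2d_pattern above enforces
# before a qconv weight prepack node is inserted.
def _example_can_prepack_qconv(device_type, ndim, groups, dequant_user_count):
    if device_type not in ("cpu", "xpu") or ndim != 4:
        # Only 4-D (conv2d) tensors on CPU or XPU are lowered
        return False
    if device_type == "xpu" and groups != 1:
        # Grouped quantized convolution is not supported on XPU
        return False
    # The dequant node is erased afterwards, so conv must be its only user
    return dequant_user_count == 1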
def _register_qconv_weight_prepack_pass(pattern, pass_number, dtype=torch.float32):
@register_freezing_graph_pattern(
pattern,
extra_check=_is_valid_dequant_conv2d_pattern(dtype),
pass_number=pass_number,
)
def qconv_weight_prepack(match: Match, *args, **kwargs):
"""
Match the pattern:
int8 activation
|
dequant_per_tensor
|
Conv2d <- optional(aten.clone.default) <- dequant_per_channel <- int8_weight
Insert weight prepack node and change the pattern to:
int8 activation
|
onednn.qconv2d_pointwise <- onednn.qconv_prepack <- int8_weight
"""
assert dtype in [torch.float32, torch.bfloat16]
conv_node = match.output_node()
assert conv_node.target is aten.convolution.default
if dtype == torch.float32:
dequant_node = conv_node.args[0]
else:
convert_to_bf16 = conv_node.args[0]
dequant_node = convert_to_bf16.args[0] # type: ignore[union-attr]
has_clone_to_channel_last_node_in_pattern = (
conv_node.args[1].target is aten.clone.default # type: ignore[union-attr]
)
clone_node = (
conv_node.args[1] if has_clone_to_channel_last_node_in_pattern else None
)
if dtype == torch.float32:
dequant_per_channel = (
clone_node.args[0] # type: ignore[union-attr]
if has_clone_to_channel_last_node_in_pattern
else conv_node.args[1]
)
else:
weight_to_bf16_node = (
clone_node.args[0] # type: ignore[union-attr]
if has_clone_to_channel_last_node_in_pattern
else conv_node.args[1]
)
dequant_per_channel = weight_to_bf16_node.args[0] # type: ignore[union-attr]
assert (
dequant_per_channel.target # type: ignore[union-attr]
is quantized_decomposed.dequantize_per_channel.default
)
# Activation QParams
qx, x_zp, x_scale = (
kwargs["x"],
kwargs["x_zp"],
kwargs["x_scale"],
)
# Weight QParams
qw, w_scale, w_zp = (
kwargs["q_weight"],
kwargs["w_scale"],
kwargs["w_zp"],
)
# Conv Params
bias, stride, padding, dilation, groups = (
kwargs["b"],
kwargs["stride"],
kwargs["padding"],
kwargs["dilation"],
kwargs["groups"],
)
x_shape = qx.meta.get("tensor_meta").shape
if has_free_symbols(x_shape):
# For dynamic shape case, we can't get activation shape ahead of runtime.
x_shape = None
graph = match.graph
with graph.inserting_before(conv_node):
# Insert weight prepack node and the QConv node
packed_weight_inputs = (
qw,
w_scale,
x_scale,
x_zp,
stride,
padding,
dilation,
groups,
x_shape,
)
packed_weight_op = torch.ops.onednn.qconv_prepack
prepack_weight_node = graph.call_function(
packed_weight_op, args=packed_weight_inputs
)
new_args: Tuple[Any, ...] = (
qx,
x_scale,
x_zp,
prepack_weight_node,
w_scale,
w_zp,
bias,
stride,
padding,
dilation,
groups,
1.0, # output_scale
0, # output_zero_point
dtype, # output_dtype
"none", # attr
[], # scalars
"", # algorithm
)
new_conv_node = graph.call_function(
torch.ops.onednn.qconv2d_pointwise.default, args=new_args
)
conv_node.replace_all_uses_with(new_conv_node)
new_conv_node.meta.update(conv_node.meta)
# Erase the original conv node
graph.erase_node(conv_node)
# Erase the dequant pattern
if dtype == torch.bfloat16:
graph.erase_node(convert_to_bf16) # type: ignore[possibly-undefined, arg-type]
graph.erase_node(dequant_node) # type: ignore[arg-type]
# Erase the dequant per channel pattern
if clone_node is not None:
graph.erase_node(clone_node) # type: ignore[arg-type]
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined, arg-type]
graph.erase_node(dequant_per_channel) # type: ignore[arg-type]
counters["inductor"]["qconv2d_weight_prepack_matcher_count"] += 1
counters["inductor"]["qconv2d_weight_prepack_matcher_nodes"] += len(
match.nodes
)
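# Illustrative sketch (hypothetical helpers, not used by any pass in this file):
# a toy model of the rewrite order used by the weight prepack matchers above.
# The replacement node is inserted, every user of the old output node is
# rerouted to it (as replace_all_uses_with does), and the old nodes are then
# erased from the consumer side back towards the producers. Like
# torch.fx.Graph.erase_node, this toy refuses to erase a node that still has
# users, so erasing a producer first fails.
class _SketchNode:
    def __init__(self, name, *inputs):
        self.name = name
        self.inputs = list(inputs)
        self.users = []
        for inp in self.inputs:
            inp.users.append(self)


def _sketch_replace_all_uses_with(old, new):
    # Point every user of `old` at `new` instead.
    for user in old.users:
        user.inputs = [new if inp is old else inp for inp in user.inputs]
        new.users.append(user)
    old.users = []


def _sketch_erase(node, nodes):
    # Mirrors the FX invariant: only nodes without users may be erased.
    if node.users:
        raise RuntimeError(f"cannot erase {node.name}: it still has users")
    for inp in node.inputs:
        inp.users.remove(node)
    nodes.remove(node)


def _sketch_rewrite_dequant_conv(erase_producer_first=False):
    # Toy chain: q_weight -> dequant_per_channel -> conv -> consumer
    q_weight = _SketchNode("q_weight")
    dequant = _SketchNode("dequant_per_channel", q_weight)
    conv = _SketchNode("conv", dequant)
    consumer = _SketchNode("consumer", conv)
    new_conv = _SketchNode("qconv2d_pointwise", q_weight)
    nodes = [q_weight, new_conv, dequant, conv, consumer]
    _sketch_replace_all_uses_with(conv, new_conv)
    order = [dequant, conv] if erase_producer_first else [conv, dequant]
    for node in order:
        _sketch_erase(node, nodes)
    return [n.name for n in nodes]


# Usage: _sketch_rewrite_dequant_conv() leaves q_weight -> qconv2d_pointwise -> consumer.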
def _generate_dequant_convolution_node_pattern(
_dequant_per_channel_pattern, dtype=torch.float32
):
assert dtype in [torch.float32, torch.bfloat16]
dequant_convolution_node_pattern = CallFunction(
aten.convolution.default,
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
_dequant_per_channel_pattern,
KeywordArg("b"),
KeywordArg("stride"),
KeywordArg("padding"),
KeywordArg("dilation"),
KeywordArg("is_transposed"),
KeywordArg("out_padding"),
KeywordArg("groups"),
)
return dequant_convolution_node_pattern
def _generate_qconv_weight_prepack_patterns(dtype=torch.float32):
assert dtype in [torch.float32, torch.bfloat16]
return (
_generate_dequant_convolution_node_pattern(
dequantize_per_channel_weight_pattern
if dtype == torch.float32
else dequantize_per_channel_to_bf16_weight_pattern,
dtype,
),
        # There is another pattern due to the convert_conv_weights_to_channels_last pass:
        # https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/_inductor/freezing.py#L338-L362.
        # Depending on some heuristics, it may or may not insert a to(channels_last) node
        # (matched as aten.clone) between the convolution and the dequant_per_channel node.
_generate_dequant_convolution_node_pattern(
dequantize_per_channel_clone_weight_pattern
if dtype == torch.float32
else dequantize_per_channel_to_bf16_clone_weight_pattern,
dtype,
),
)
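# Illustrative sketch (hypothetical helper, not used by any pass in this file):
# for the fp32 case, the two conv patterns registered above differ only in
# whether an aten.clone node, left by the channels-last weight conversion, sits
# between the conv weight input and dequantize_per_channel. Ops are modeled
# here as (op_name, input) tuples; the optional clone is peeled off before
# checking for the dequant node, much as qconv_weight_prepack does with
# conv_node.args[1]. The bf16 to_bf16 conversion is ignored in this toy.
def _sketch_find_conv_weight_dequant(weight_arg):
    clone = None
    if weight_arg[0] == "clone":
        clone, weight_arg = weight_arg, weight_arg[1]
    if weight_arg[0] != "dequantize_per_channel":
        raise ValueError(f"unexpected weight producer: {weight_arg[0]}")
    return weight_arg, clone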
def _get_linear_node(match, input_dim_exceeds_two, input_contiguous):
output_reshape_node = None
if input_dim_exceeds_two:
if input_contiguous:
output_reshape_node = match.output_node()
assert output_reshape_node.target is aten.reshape.default
linear_node = output_reshape_node.args[0]
else:
linear_nodes = filter_nodes(match.nodes, aten.bmm.default)
assert len(linear_nodes) == 1
linear_node = linear_nodes[0]
else:
linear_node = match.output_node()
assert linear_node.target in (
aten.addmm.default,
aten.mm.default,
aten.bmm.default,
)
return linear_node, output_reshape_node
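# Illustrative sketch (hypothetical helper, not used by any pass in this file):
# why the linear matchers derive input_index from the node target.
# aten.addmm(self, mat1, mat2) carries the bias as `self`, so the activation is
# args[1] and the transposed weight args[2]; aten.mm(self, mat2) and
# aten.bmm(self, mat2) have no bias argument, so they are args[0] and args[1].
def _sketch_linear_arg_indices(target_name):
    if target_name not in ("addmm", "mm", "bmm"):
        raise ValueError(f"not a linear-like op: {target_name}")
    input_index = 1 if target_name == "addmm" else 0
    weight_index = input_index + 1
    return input_index, weight_index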
def _get_linear_dq_node(
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
):
act_reshape_node = None
activation_to_bf16_node = None
act_expand_node = None
if input_dim_exceeds_two:
if input_contiguous:
act_reshape_node = linear_node.args[input_index]
assert act_reshape_node.target is aten.reshape.default
if dtype == torch.float32:
# pattern: linear -> reshape -> dequant
dequant_node = act_reshape_node.args[0]
else:
# pattern: linear -> reshape -> to_bf16 -> dequant
activation_to_bf16_node = act_reshape_node.args[0]
dequant_node = activation_to_bf16_node.args[0]
else:
# bmm pattern decomposed from linear when input dim exceeds 2 and not contiguous
act_expand_node = linear_node.args[input_index]
assert act_expand_node.target is aten.expand.default
if dtype == torch.float32:
dequant_node = act_expand_node.args[0]
else:
activation_to_bf16_node = act_expand_node.args[0]
dequant_node = activation_to_bf16_node.args[0]
else:
if dtype == torch.float32:
# pattern: linear -> dequant
dequant_node = linear_node.args[input_index]
else:
# pattern: linear -> to_bf16 -> dequant
activation_to_bf16_node = linear_node.args[input_index]
dequant_node = activation_to_bf16_node.args[0]
return dequant_node, act_reshape_node, activation_to_bf16_node, act_expand_node
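# Illustrative sketch (hypothetical helper, not used by any pass in this file):
# _get_linear_dq_node walks back from the linear activation input through an
# optional reshape (contiguous input with more than 2 dims) or expand (bmm
# case), then an optional to_bf16 conversion (int8-mixed-bf16), until it reaches
# the per-tensor dequant. The same walk over a toy chain of (op_name, input)
# tuples, returning the dequant and the ops that were peeled off:
def _sketch_walk_to_act_dequant(act_arg, is_bf16, wrapper_op=None):
    peeled = []
    expected = ([wrapper_op] if wrapper_op else []) + (["to_bf16"] if is_bf16 else [])
    for op in expected:
        if act_arg[0] != op:
            raise ValueError(f"expected {op}, got {act_arg[0]}")
        peeled.append(op)
        act_arg = act_arg[1]
    if act_arg[0] != "dequantize_per_tensor":
        raise ValueError(f"expected dequantize_per_tensor, got {act_arg[0]}")
    return act_arg, peeled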
def _is_valid_dequant_linear_pattern(dtype, input_dim_exceeds_two, input_contiguous):
def _inner(match):
# Check dequant pattern has only 1 user.
(
linear_node,
_,
) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
input_index = 1 if linear_node.target is aten.addmm.default else 0
assert dtype in [torch.float32, torch.bfloat16]
(
dequant_node,
_,
_,
_,
) = _get_linear_dq_node(
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
)
assert dequant_node.target in [
quantized_decomposed.dequantize_per_tensor.default,
quantized_decomposed.dequantize_per_tensor.tensor,
]
if len(list(dequant_node.users)) != 1:
# Ensure the dequant pattern only has 1 user
# since we will delete the dequant pattern here
return False
# Extra check for bmm pattern
if input_dim_exceeds_two and not input_contiguous:
# Check for act
            # Act expand size should be exactly the same as act size
act_expand_size = match.kwargs["act_expand_size"]
act_node = match.kwargs["x"]
if not (
hasattr(act_node, "meta")
and isinstance(act_node.meta.get("val", None), torch.Tensor)
and (act_node.meta["val"].size() == torch.Size(act_expand_size))
):
return False
# Check for wgt
# wgt permute dims should be [1, 0]
wgt_permute_dims = match.kwargs["permute_axes"]
if wgt_permute_dims != [1, 0]:
return False
            # Check the following wgt size items:
            # wgt before expand should have dim 2
            # Expand size should have dim 3
            # Expand size[0] should be the same as act size[0]
            # Expand size[1] should be the same as wgt size[1]
            # Expand size[2] should be the same as wgt size[0]
qweight_node = match.kwargs["q_weight"]
wgt_expand_size = match.kwargs["wgt_expand_size"]
if not (
hasattr(qweight_node, "meta")
and isinstance(qweight_node.meta.get("val", None), torch.Tensor)
and len(qweight_node.meta["val"].size()) == 2
and len(wgt_expand_size) == 3
and wgt_expand_size[0] == act_node.meta["val"].size()[0]
and wgt_expand_size[1] == qweight_node.meta["val"].size()[1]
and wgt_expand_size[2] == qweight_node.meta["val"].size()[0]
):
return False
return True
return _inner
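# Illustrative sketch (hypothetical helper, not used by any pass in this file):
# the extra bmm checks of _is_valid_dequant_linear_pattern restated on plain
# size tuples. A bmm is treated as a decomposed linear only if the activation
# expand is a no-op, and the 2-D weight [N, K] is transposed with permute [1, 0]
# and broadcast to [B, K, N] for an activation of shape [B, M, K].
def _sketch_is_valid_bmm_linear(
    act_size, act_expand_size, permute_axes, wgt_size, wgt_expand_size
):
    if tuple(act_size) != tuple(act_expand_size):
        return False
    if list(permute_axes) != [1, 0]:
        return False
    return (
        len(wgt_size) == 2
        and len(wgt_expand_size) == 3
        and wgt_expand_size[0] == act_size[0]
        and wgt_expand_size[1] == wgt_size[1]
        and wgt_expand_size[2] == wgt_size[0]
    )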
def _register_qlinear_weight_prepack_pass(
pattern,
pass_number,
dtype=torch.float32,
input_dim_exceeds_two=False,
input_contiguous=True,
):
@register_freezing_graph_pattern(
pattern,
extra_check=_is_valid_dequant_linear_pattern(
dtype, input_dim_exceeds_two, input_contiguous
),
pass_number=pass_number,
)
def qlinear_weight_prepack(match: Match, *args, **kwargs):
"""
Match the pattern:
int8 activation
|
dequant_per_tensor
|
mm/addmm <- t <- dequant_per_channel <- int8_weight
Insert weight prepack node and change the pattern to:
int8 activation
|
onednn.qlinear_pointwise <- onednn.qlinear_prepack <- int8_weight
"""
assert dtype in [torch.float32, torch.bfloat16]
(
linear_node,
output_reshape_node,
) = _get_linear_node(match, input_dim_exceeds_two, input_contiguous)
input_index = 1 if linear_node.target is aten.addmm.default else 0
weight_index = input_index + 1
(
dequant_node,
act_reshape_node,
activation_to_bf16_node,
act_expand_node,
) = _get_linear_dq_node(
linear_node, input_index, dtype, input_dim_exceeds_two, input_contiguous
)
if input_dim_exceeds_two and not input_contiguous:
wgt_expand_node = linear_node.args[weight_index]
assert wgt_expand_node.target is aten.expand.default
t_node = wgt_expand_node.args[0]
else:
t_node = linear_node.args[weight_index]
if dtype == torch.float32:
dequant_per_channel = t_node.args[0]
else:
weight_to_bf16_node = t_node.args[0]
dequant_per_channel = weight_to_bf16_node.args[0]
assert (
dequant_per_channel.target
is quantized_decomposed.dequantize_per_channel.default
)
# Activation QParams
qx, x_zp, x_scale = (
kwargs["x"],
kwargs["x_zp"],
kwargs["x_scale"],
)
# Weight QParams
qw, w_scale, w_zp = (
kwargs["q_weight"],
kwargs["w_scale"],
kwargs["w_zp"],
)
# Params
bias = kwargs["b"] if "b" in kwargs else None
x_shape = qx.meta.get("tensor_meta").shape
if has_free_symbols(x_shape):
            # In the dynamic shape case, the activation shape is not known ahead of runtime.
x_shape = None
graph = match.graph
with graph.inserting_before(linear_node):
# Insert weight prepack node and the qlinear node
packed_weight_inputs = (
qw,
x_shape,
)
packed_weight_op = torch.ops.onednn.qlinear_prepack
prepack_weight_node = graph.call_function(
packed_weight_op, args=packed_weight_inputs
)
new_args: Tuple[Any, ...] = (
qx,
x_scale,
x_zp,
prepack_weight_node,
w_scale,
w_zp,
bias,
1.0, # output_scale
0, # output_zero_point
dtype, # output_dtype
"none", # post op name
[], # post op args
"", # post op algorithm
)
Node = torch.fx.node.Node
if isinstance(x_scale, Node) and isinstance(x_zp, Node):
new_linear_node = graph.call_function(
torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
)
else:
new_linear_node = graph.call_function(
torch.ops.onednn.qlinear_pointwise.default, args=new_args
)
if input_dim_exceeds_two:
if input_contiguous:
output_reshape_node.replace_all_uses_with(new_linear_node)
new_linear_node.meta.update(output_reshape_node.meta)
else:
if bias:
output_add_node_for_bias = match.output_node()
assert output_add_node_for_bias.target is aten.add.Tensor
output_add_node_for_bias.replace_all_uses_with(new_linear_node)
new_linear_node.meta.update(output_add_node_for_bias.meta)
else:
linear_node.replace_all_uses_with(new_linear_node)
new_linear_node.meta.update(linear_node.meta)
else:
linear_node.replace_all_uses_with(new_linear_node)
new_linear_node.meta.update(linear_node.meta)
# Erase the original linear node
if input_dim_exceeds_two:
if input_contiguous:
graph.erase_node(output_reshape_node)
elif not input_contiguous and bias:
graph.erase_node(output_add_node_for_bias) # type: ignore[possibly-undefined]
graph.erase_node(linear_node)
if input_dim_exceeds_two:
if input_contiguous:
graph.erase_node(act_reshape_node)
else:
graph.erase_node(act_expand_node)
graph.erase_node(wgt_expand_node) # type: ignore[possibly-undefined]
if dtype == torch.bfloat16:
graph.erase_node(activation_to_bf16_node)
# Erase the dequant pattern
graph.erase_node(dequant_node)
# Erase the dequant per channel pattern
graph.erase_node(t_node)
if dtype == torch.bfloat16:
graph.erase_node(weight_to_bf16_node) # type: ignore[possibly-undefined]
graph.erase_node(dequant_per_channel)
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
match.nodes
)
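# Illustrative sketch (hypothetical helpers, not used by any pass in this file):
# the overload choice made near the end of qlinear_weight_prepack. When both the
# activation scale and zero point are FX nodes (tensor values in the graph), the
# qlinear_pointwise.tensor overload is used; otherwise they are Python scalars
# and the .default overload is used. _SketchFxNode stands in for torch.fx.Node.
class _SketchFxNode:
    pass


def _sketch_pick_qlinear_overload(x_scale, x_zp, node_cls=_SketchFxNode):
    if isinstance(x_scale, node_cls) and isinstance(x_zp, node_cls):
        return "qlinear_pointwise.tensor"
    return "qlinear_pointwise.default"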
def _generate_dequant_linear_node_pattern(
_dequant_per_channel_pattern,
dtype=torch.float32,
input_dim_exceeds_two=False,
is_tensor_overload=False,
):
assert dtype in [torch.float32, torch.bfloat16]
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
dequant_linear_bias_pattern = _may_generate_pattern_with_reshape(
CallFunction(
aten.addmm.default,
KeywordArg("b"),
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
KeywordArg("act_reshape_size"),
input_dim_exceeds_two,
),
t_pattern,
),
KeywordArg("output_reshape_size"),
input_dim_exceeds_two,
)
dequant_linear_no_bias_pattern = _may_generate_pattern_with_reshape(
CallFunction(
aten.mm.default,
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
KeywordArg("act_reshape_size"),
input_dim_exceeds_two,
),
t_pattern,
),
KeywordArg("output_reshape_size"),
input_dim_exceeds_two,
)
return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
def _generate_dequant_bmm_node_pattern(
_dequant_per_channel_pattern,
dtype=torch.float32,
with_bias=False,
is_tensor_overload=False,
):
    # When the linear activation has more than 2 dims and is not contiguous
t_pattern = _generate_linear_t_pattern(_dequant_per_channel_pattern, dtype)
assert dtype in [torch.float32, torch.bfloat16]
dequant_bmm_pattern = CallFunction(
aten.bmm.default,
CallFunction(
aten.expand.default,
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(is_tensor_overload),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
KeywordArg("act_expand_size"),
),
CallFunction(
aten.expand.default,
t_pattern,
KeywordArg("wgt_expand_size"),
),
)
def _generate_pattern_with_output_add(_dequant_bmm_pattern, _with_bias):
if _with_bias:
return CallFunction(
aten.add.Tensor,
_dequant_bmm_pattern,
KeywordArg("b"),
)
else:
return _dequant_bmm_pattern
return _generate_pattern_with_output_add(dequant_bmm_pattern, with_bias)
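# Illustrative sketch (hypothetical helpers, not used by any pass in this file):
# a pure-Python check that a linear on a [B, M, K] activation equals the
# decomposed form matched above: transpose the [N, K] weight to [K, N]
# (permute [1, 0]), expand it to [B, K, N], run a batched matmul, then add the
# optional bias. This equivalence is why the bmm pattern can be rewritten back
# into a single qlinear.
def _sketch_linear_3d(x, w, b=None):
    # Reference linear: y[i][m][n] = sum_k x[i][m][k] * w[n][k] (+ b[n])
    out = [[[sum(xk * wk for xk, wk in zip(row, w_row)) for w_row in w] for row in xb] for xb in x]
    if b is not None:
        out = [[[v + bn for v, bn in zip(row, b)] for row in ob] for ob in out]
    return out


def _sketch_linear_via_bmm(x, w, b=None):
    w_t = [list(col) for col in zip(*w)]  # permute [1, 0]: [N, K] -> [K, N]
    w_expanded = [w_t] * len(x)  # expand to [B, K, N] (shared, not copied)
    out = [
        [[sum(a * w_k[n] for a, w_k in zip(row, wb)) for n in range(len(wb[0]))] for row in xb]
        for xb, wb in zip(x, w_expanded)
    ]
    if b is not None:  # the optional aten.add.Tensor after the bmm
        out = [[[v + bn for v, bn in zip(row, b)] for row in ob] for ob in out]
    return out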
def _generate_qlinear_weight_prepack_patterns(
dtype=torch.float32,
input_dim_exceeds_two=False,
input_contiguous=True,
with_bias=False,
is_tensor_overload=False,
):
if input_dim_exceeds_two and not input_contiguous:
return _generate_dequant_bmm_node_pattern(
dequantize_per_channel_weight_pattern,
dtype,
with_bias,
is_tensor_overload,
)
else:
return _generate_dequant_linear_node_pattern(
dequantize_per_channel_weight_pattern,
dtype,
input_dim_exceeds_two,
is_tensor_overload,
)
def _generate_linear_dynamic_fp16_pattern(
_dequant_weight_pattern,
input_dim_exceeds_two=False,
input_contiguous=True,
relu_fused=False,
):
dtype = torch.float32
t_pattern = _generate_linear_t_pattern(_dequant_weight_pattern, dtype)
if input_dim_exceeds_two and not input_contiguous:
        # pattern is
        #                     x -> expand -> bmm (-> add) (-> relu)
        # w -> dequant -> permute -> expand /
pattern_no_bias = CallFunction(
aten.bmm.default,
CallFunction(
aten.expand.default,
KeywordArg("x"),
KeywordArg("act_expand_size"),
),
CallFunction(
aten.expand.default,
t_pattern,
KeywordArg("wgt_expand_size"),
),
)
pattern_with_bias = CallFunction(
aten.add.Tensor,
pattern_no_bias,
KeywordArg("b"),
)
if relu_fused:
pattern_with_bias = CallFunction(aten.relu.default, pattern_with_bias)
pattern_no_bias = CallFunction(aten.relu.default, pattern_no_bias)
return pattern_with_bias, pattern_no_bias
x_pattern_with_reshape = _may_generate_pattern_with_reshape(
KeywordArg("x"),
KeywordArg("act_reshape_size"),
input_dim_exceeds_two,
)
dequant_linear_bias_pattern = generate_pattern_with_unary(
_may_generate_pattern_with_reshape(
CallFunction(
aten.addmm.default,
KeywordArg("b"),
x_pattern_with_reshape,
t_pattern,
),
KeywordArg("output_reshape_size"),
input_dim_exceeds_two,
),
aten.relu.default if relu_fused else None,
)
dequant_linear_no_bias_pattern = generate_pattern_with_unary(
_may_generate_pattern_with_reshape(
CallFunction(
aten.mm.default,
x_pattern_with_reshape,
t_pattern,
),
KeywordArg("output_reshape_size"),
input_dim_exceeds_two,
),
aten.relu.default if relu_fused else None,
)
return dequant_linear_bias_pattern, dequant_linear_no_bias_pattern
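# Illustrative sketch (hypothetical helper, not used by any pass in this file):
# in the dynamic fp16 patterns the weight goes through convert_element_type to
# fp16 and then back to fp32, so the fp32 weight is effectively rounded to IEEE
# 754 binary16 precision. The same round trip on a single Python float, using the
# struct module's "e" (half precision) format:
def _sketch_fp16_round_trip(value):
    import struct

    return struct.unpack("<e", struct.pack("<e", value))[0]


# Usage: _sketch_fp16_round_trip(0.1) gives 0.0999755859375, while 1.5 is exact.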
def _register_dequant_promotion():
dequant_pattern_cases = itertools.product(
[torch.float32, torch.bfloat16], [True, False], [True, False]
)
for dtype, input_dim_exceeds_two, is_tensor_overload in dequant_pattern_cases:
        # 4 dequantization patterns will be matched based on the dtype and input dimension size.
        # Each case is registered twice, with is_tensor_overload set to True and to False.
        # Case 1: int8-mixed-fp32, input dim size is 2
        # Case 2: int8-mixed-fp32, input dim size exceeds 2
        # Case 3: int8-mixed-bf16, input dim size is 2
        # Case 4: int8-mixed-bf16, input dim size exceeds 2
        #           quant
        #   + - - - - | - - - - +
        #   |      dequant      |
        #   |         |         |
        #   |    OPT(to_bf16)   |
        #   |         |         |
        #   |    OPT(reshape)   |
        #   |      /     \      |
        #   |   node1   node2   |
        #   + - - | - - - | - - +
        #  OPT(reshape) OPT(reshape)
        #   + - - | - - - | - - +
        #  OPT(to_fp32) OPT(to_fp32)
        #   + - - | - - - | - - +
        #       quant   quant
_register_dequant_promotion_pass(
_may_generate_pattern_with_reshape(
_may_generate_pattern_with_dtype_convert(
get_dequantize_per_tensor_activation_pattern(
is_tensor_overload=is_tensor_overload
),
KeywordArg("autocast_act_dtype"),
dtype == torch.bfloat16,
),
KeywordArg("act_reshape_size"),
with_reshape=input_dim_exceeds_two,
),
pass_number=0,
dtype=dtype,
) # pass_number=0 to run before weight prepack
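# Illustrative sketch (hypothetical helper, not used by any pass in this file):
# the idea behind dequant promotion on a toy graph stored as {node: [users]}.
# A dequant node shared by several users is replaced by one copy per user, so
# each copy has exactly one user. The weight prepack matchers require this (see
# the single-user check in _is_valid_dequant_linear_pattern), which is why the
# promotion is registered at pass_number=0, before weight prepack.
def _sketch_promote_dequant(edges, producer, dequant):
    edges = {k: list(v) for k, v in edges.items()}  # do not mutate the input
    users = edges.pop(dequant)
    copies = [f"{dequant}_{i}" for i in range(len(users))]
    for copy, user in zip(copies, users):
        edges[copy] = [user]
    edges[producer] = [n for n in edges[producer] if n != dequant] + copies
    return edges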
def _register_qconv_weight_prepack():
for dtype in [torch.float32, torch.bfloat16]:
weight_prepack_patterns = _generate_qconv_weight_prepack_patterns(dtype)
for weight_prepack_pattern in weight_prepack_patterns:
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
_register_qconv_weight_prepack_pass(
weight_prepack_pattern, pass_number=1, dtype=dtype
)
def _register_qlinear_weight_prepack():
    # 6 Linear related patterns will be matched based on the dtype, input dimension size and input contiguity.
    # Each matched pattern is then converted into a QLinear node with int8-mixed-fp32/bf16.
    # Case 1: int8-mixed-fp32, input dim size is 2
    # Case 2: int8-mixed-fp32, input dim size exceeds 2 and contiguous
    # Case 3: int8-mixed-bf16, input dim size is 2
    # Case 4: int8-mixed-bf16, input dim size exceeds 2 and contiguous
    #   + - - - - | - - - - - - | - - - - - +
    #   |   dq_per_tensor dq_per_channel    |
    #   |         |             |           |
    #   |   OPT(to_bf16)  OPT(to_bf16)      |
    #   |         |             |           |
    #   |   OPT(reshape)     permute        |
    #   |           \         /             |
    #   |            addmm/mm               |
    #   |                |                  |
    #   |          OPT(reshape)             |
    # Case 5: int8-mixed-fp32, input dim size exceeds 2 and not contiguous
    # Case 6: int8-mixed-bf16, input dim size exceeds 2 and not contiguous
    #   + - - - - | - - - - - - | - - - - - +
    #   |   dq_per_tensor dq_per_channel    |
    #   |         |             |           |
    #   |   OPT(to_bf16)  OPT(to_bf16)      |
    #   |         |             |           |
    #   |      expand        permute        |
    #   |            \          |           |
    #   |             \      expand         |
    #   |              \    /               |
    #   |               bmm                 |
    #   |                |                  |
    #   |             OPT(add)              |
linear_weight_prepack_cases = itertools.product(
[torch.float32, torch.bfloat16], [True, False], [True, False]
)
# Step 1: register patterns from mm and addmm
for dtype, input_dim_exceeds_two, is_tensor_overload in linear_weight_prepack_cases:
weight_prepack_patterns = _generate_qlinear_weight_prepack_patterns(
dtype,
input_dim_exceeds_two,
is_tensor_overload=is_tensor_overload,
)
for weight_prepack_pattern in weight_prepack_patterns:
# Register to pass_number 1, so we can do dequant promotion in pass_number 0.
_register_qlinear_weight_prepack_pass(
weight_prepack_pattern,
pass_number=1,
dtype=dtype,
input_dim_exceeds_two=input_dim_exceeds_two,
)
# Step 2: register patterns from bmm
# Linear might be decomposed into bmm when input dim exceeds 2 and not contiguous
# refer to:
# https://github.com/pytorch/pytorch/blob/
# 80c07df659362a95da7cd4f3ec367abfdace38c4/torch/_decomp/decompositions.py#L3965-L3968
    # In this case, we can convert it back to qlinear.
for dtype, with_bias, is_tensor_overload in itertools.product(
[torch.float32, torch.bfloat16], [True, False], [True, False]
):
bmm_pattern = _generate_qlinear_weight_prepack_patterns(
dtype=dtype,
input_dim_exceeds_two=True,
input_contiguous=False,
with_bias=with_bias,
is_tensor_overload=is_tensor_overload,
)
_register_qlinear_weight_prepack_pass(
bmm_pattern,
pass_number=1
if with_bias
            else 2,  # if with_bias, there is an output add, so we should try to match it first
dtype=dtype,
input_dim_exceeds_two=True,
input_contiguous=False,
)
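# Illustrative sketch (hypothetical helper, not used by any pass in this file):
# why the bmm pattern with a bias add is registered at pass_number=1 and the
# bias-free one at pass_number=2. Lower pass numbers run first; if the bare bmm
# pattern ran first, it would consume the bmm and leave the add unfused. This
# toy rewrites a chain of op names pass by pass, replacing each matched run with
# "qlinear"; real matching works on the FX dataflow graph, not on a flat list.
def _sketch_run_passes(ops, patterns_by_pass):
    for pass_number in sorted(patterns_by_pass):
        for pattern in patterns_by_pass[pass_number]:
            out, i = [], 0
            while i < len(ops):
                if tuple(ops[i : i + len(pattern)]) == tuple(pattern):
                    out.append("qlinear")
                    i += len(pattern)
                else:
                    out.append(ops[i])
                    i += 1
            ops = out
    return ops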
def _register_linear_dynamic_fp16_weight_prepack_pass(
pattern,
pass_number,
input_dim_exceeds_two=False,
input_contiguous=True,
relu_fused=False,
):
def _extra_check_fn(match: Match):
return match.kwargs["dtype_fp16"] == torch.float16
@register_freezing_graph_pattern(
pattern,
extra_check=_extra_check_fn,
pass_number=pass_number,
)
def linear_dynamic_fp16_weight_prepack(match: Match, *args, **kwargs):
"""
Match the pattern:
fp32 activation
|
mm/addmm <- t <- to_fp32 <- to_fp16 <- weight
|
(reshape) <- (relu)
OR
fp32 activation
|
expand
|
bmm <- expand <- t <- to_fp32 <- to_fp16 <- weight
|
(add) <- (relu)
Insert weight prepack node and change the pattern to:
fp32 activation
|
onednn.linear_dynamic_fp16 <- onednn.linear_prepack_fp16 <- weight
(or onednn.linear_relu_dynamic_fp16)
"""
# find params
x = kwargs["x"]
w = kwargs["w"]
bias = kwargs["b"] if "b" in kwargs else None
# find linear node
nodes_to_find = [aten.addmm.default, aten.mm.default, aten.bmm.default]
linear_nodes = []
for node in nodes_to_find:
linear_nodes.extend(filter_nodes(match.nodes, node))
assert len(linear_nodes) == 1
linear_node = linear_nodes[0]
assert isinstance(linear_node, torch.fx.node.Node)
input_index = 1 if linear_node.target is aten.addmm.default else 0
weight_index = input_index + 1
# find relu node
relu_node = None
if relu_fused:
relu_node = match.output_node()
assert isinstance(relu_node, torch.fx.node.Node)
# find reshape node, expand node and add node
(
act_reshape_node,
output_reshape_node,
expand_x_node,
expand_w_node,
add_bias_node,
) = (None, None, None, None, None)
t_node = None
if input_dim_exceeds_two:
if input_contiguous:
act_reshape_node = linear_node.args[input_index]
t_node = linear_node.args[weight_index]
output_reshape_node = next(iter(linear_node.users))
assert output_reshape_node.target is aten.reshape.default
else:
expand_x_node = linear_node.args[input_index]
expand_w_node = linear_node.args[weight_index]
assert isinstance(expand_w_node, torch.fx.node.Node)
t_node = expand_w_node.args[0]
if bias:
add_bias_node = next(iter(linear_node.users))
assert add_bias_node.target is aten.add.Tensor
else:
t_node = linear_node.args[weight_index]
assert isinstance(t_node, torch.fx.node.Node)
w_to_fp32_node = t_node.args[0]
assert (
isinstance(w_to_fp32_node, torch.fx.node.Node)
and w_to_fp32_node.target
is quantized_decomposed.convert_element_type.no_fuse
)
w_to_fp16_node = w_to_fp32_node.args[0]
assert (
isinstance(w_to_fp16_node, torch.fx.node.Node)
and w_to_fp16_node.target
is quantized_decomposed.convert_element_type.no_fuse
)
x_shape = x.meta.get("tensor_meta").shape
if has_free_symbols(x_shape):
            # In the dynamic shape case, the activation shape is not known ahead of runtime.
x_shape = None
graph = match.graph
with graph.inserting_before(linear_node):
            # Insert the weight prepack node and the dynamic fp16 linear node
packed_weight_inputs = (
w,
x_shape,
)
packed_weight_op = torch.ops.onednn.linear_prepack_fp16
prepack_weight_node = graph.call_function(
packed_weight_op, args=packed_weight_inputs
)
            # Create the new linear node and insert it into the graph
new_args: Tuple[Any, ...] = (
x,
prepack_weight_node,
bias,
)
linear_op = (
torch.ops.onednn.linear_relu_dynamic_fp16.default
if relu_fused
else torch.ops.onednn.linear_dynamic_fp16.default
)
new_linear_node = graph.call_function(linear_op, args=new_args)
out_node = match.output_node()
out_node.replace_all_uses_with(new_linear_node)
# Erase the original nodes in the reverse order
new_linear_node.meta.update(out_node.meta)
if relu_node is not None:
graph.erase_node(relu_node)
if output_reshape_node is not None:
graph.erase_node(output_reshape_node)
if add_bias_node is not None:
graph.erase_node(add_bias_node)
graph.erase_node(linear_node)
if act_reshape_node is not None:
assert isinstance(act_reshape_node, torch.fx.node.Node)
graph.erase_node(act_reshape_node)
if expand_x_node is not None:
assert isinstance(expand_x_node, torch.fx.node.Node)
graph.erase_node(expand_x_node)
if expand_w_node is not None:
assert isinstance(expand_w_node, torch.fx.node.Node)
graph.erase_node(expand_w_node)
graph.erase_node(t_node)
graph.erase_node(w_to_fp32_node)
graph.erase_node(w_to_fp16_node)
counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
match.nodes
)
def _register_linear_dynamic_fp16_weight_prepack():
to_dtype_op = torch.ops.quantized_decomposed.convert_element_type.no_fuse
weight_pattern = CallFunction(
to_dtype_op,
CallFunction(
to_dtype_op,
KeywordArg("w"),
KeywordArg("dtype_fp16"),
),
KeywordArg("dtype_fp32"),
)
cases = itertools.product(
[False, True], # input_dim_exceeds_two
[True, False], # input_contiguous
[False, True], # relu fused
)
for input_dim_exceeds_two, input_contiguous, relu_fused in cases:
patterns = _generate_linear_dynamic_fp16_pattern(
weight_pattern,
input_dim_exceeds_two,
input_contiguous,
relu_fused,
)
for pattern in patterns:
_register_linear_dynamic_fp16_weight_prepack_pass(
pattern,
pass_number=0 if relu_fused else 1,
input_dim_exceeds_two=input_dim_exceeds_two,
input_contiguous=input_contiguous,
relu_fused=relu_fused,
)
@functools.lru_cache(None)
def _register_quantization_weight_pack_pass():
# Step 1: Dequant promotion for int8-mixed-fp32/bf16
_register_dequant_promotion()
# Step 2: QConv weight prepack
_register_qconv_weight_prepack()
# Step 3: QLinear weight prepack
_register_qlinear_weight_prepack()
_register_linear_dynamic_fp16_weight_prepack()
def quant_lift_up(graph_module: torch.fx.GraphModule):
"""
    Lift the quant node above view-like nodes. This can benefit the performance
    of attention-like blocks. For example, given the pattern:

             DQ
    DQ       LINEAR
    LINEAR   VIEW
    VIEW     PERMUTE
    PERMUTE  TRANSPOSE
    Q        Q
    DQ       DQ
       Matmul
        DIV
        ADD
      SOFTMAX

    we want to lift the quant nodes feeding the matmul above the view-like
    nodes, so that each one quantizes the output of its Linear node directly:

             DQ
    DQ       LINEAR
    LINEAR   Q
    Q        VIEW
    VIEW     PERMUTE
    PERMUTE  TRANSPOSE
    DQ       DQ
       Matmul
        DIV
        ADD
      SOFTMAX

    This produces a DQ->LINEAR->Q pattern which can be fused by the backend.
"""
def is_view_op(node):
return node.op == "call_function" and node.target in _VIEW_OPS
for node in graph_module.graph.nodes:
        # <TODO> Leslie: Here we verify that the quant node has exactly
        # one input FX node, with constant scalar values for scale and zero point.
        # For the case where the quant node has more than one input FX node,
        # extend the implementation to lift up all the connected nodes
        # before the view nodes to keep the topological order.
if (
node.op == "call_function"
and node.target in _PER_TENSOR_QUANTIZE_OPS
and len(node.all_input_nodes) == 1
and is_view_op(node.all_input_nodes[0])
):
quant_node = node
input_node_of_quant = quant_node.args[0]
            # Check that the nodes along the lift-up path have only 1 user node.
            # Walk up the view-like nodes to find where to insert the new quant node.
could_lift_up = True
current_node = quant_node
input_node = current_node.args[0]
while is_view_op(input_node):
if len(input_node.users) != 1:
could_lift_up = False
break
current_node = input_node
input_node = current_node.args[0]
            # Further check that the input node of the first view node has only 1 user node
if could_lift_up and len(input_node.users) == 1:
# Replace dequant's input from quant to quant's input
quant_node.replace_all_uses_with(input_node_of_quant)
# Insert the new quant node
with graph_module.graph.inserting_before(current_node):
new_quant_node = graph_module.graph.node_copy(quant_node)
input_node.replace_all_uses_with(new_quant_node)
# Update inputs of new_quant_node
def maybe_replace_node(n: torch.fx.Node) -> torch.fx.Node:
if n == input_node_of_quant:
return input_node
else:
return n
new_args = map_arg(new_quant_node.args, maybe_replace_node)
new_kwargs = map_arg(new_quant_node.kwargs, maybe_replace_node)
new_quant_node.args = new_args # type: ignore[assignment]
new_quant_node.kwargs = new_kwargs # type: ignore[assignment]
graph_module.graph.erase_node(quant_node)
graph_module.graph.lint()
graph_module.recompile()
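# Illustrative sketch (hypothetical helpers, not used by any pass in this file):
# quant_lift_up restated on a single producer -> consumer chain of op names.
# _SKETCH_VIEW_OPS is an assumed stand-in for _VIEW_OPS, and multi_user names
# the ops that have more than one user. The quant is moved above the run of
# view-like ops only if every op in that run, and the producer feeding it, has
# exactly one user.
_SKETCH_VIEW_OPS = frozenset({"view", "reshape", "permute", "transpose"})


def _sketch_quant_lift_up(chain, multi_user=frozenset()):
    q = chain.index("quant")
    first_view = q
    while first_view > 0 and chain[first_view - 1] in _SKETCH_VIEW_OPS:
        if chain[first_view - 1] in multi_user:
            return list(chain)  # a shared view-like op blocks the lift
        first_view -= 1
    if first_view == q or first_view == 0 or chain[first_view - 1] in multi_user:
        return list(chain)  # nothing to lift over, or the producer is shared
    return chain[:first_view] + ["quant"] + chain[first_view:q] + chain[q + 1 :]


# Usage: ["linear", "view", "permute", "quant", "dequant"] becomes
# ["linear", "quant", "view", "permute", "dequant"].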
|