1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 
1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 
2222 2223 2224 2225 2226 2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287 2288 2289 2290 2291 2292 2293 2294 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324 2325 2326 2327 2328 2329 2330 2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361 2362 2363 2364 2365 2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 2440 2441 2442 2443 2444 2445 2446 2447 2448 2449 2450 2451 2452 2453 2454 2455 2456 2457 2458 2459 2460 2461 2462 2463 2464 2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477 2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490 2491 2492 2493 2494 2495 2496 2497 2498 2499 2500 2501 2502 2503 2504 2505 2506 2507 2508 2509 2510 2511 2512 2513 2514 2515 2516 2517 2518 2519 2520 2521 2522 2523 2524 2525 2526 2527 2528 2529 2530 2531 2532 2533 2534 2535 2536 2537 2538 2539 2540 2541 2542 2543 2544 2545 2546 2547 2548 2549 2550 2551 2552 2553 2554 2555 2556 2557 2558 2559 2560 2561 2562 2563 2564 2565 2566 2567 2568 2569 2570 2571 2572 2573 2574 2575 2576 2577 2578 2579 2580 2581 2582 2583 2584 2585 2586 2587 2588 2589 2590 2591 2592 2593 2594 2595 2596 2597 2598 2599 2600 2601 2602 2603 2604 2605 2606 2607 2608 2609 2610 2611 2612 2613 2614 2615 2616 2617 2618 2619 2620 2621 
2622 2623 2624 2625 2626 2627 2628 2629 2630 2631 2632 2633 2634 2635 2636 2637 2638 2639 2640 2641 2642 2643 2644 2645 2646 2647 2648 2649 2650 2651 2652 2653 2654 2655 2656 2657 2658 2659 2660 2661 2662 2663 2664 2665 2666 2667 2668 2669 2670 2671 2672 2673 2674 2675 2676 2677 2678 2679 2680 2681 2682 2683 2684 2685 2686 2687 2688 2689 2690 2691 2692 2693 2694 2695 2696 2697 2698 2699 2700 2701 2702 2703 2704 2705 2706 2707 2708 2709 2710 2711 2712 2713 2714 2715 2716 2717 2718 2719 2720 2721 2722 2723 2724 2725 2726 2727 2728 2729 2730 2731 2732 2733 2734 2735 2736 2737 2738 2739 2740 2741 2742 2743 2744 2745 2746 2747 2748 2749 2750 2751 2752 2753 2754 2755 2756 2757 2758 2759 2760 2761 2762 2763 2764 2765 2766 2767 2768 2769 2770 2771 2772 2773 2774 2775 2776 2777 2778 2779 2780 2781 2782 2783 2784 2785 2786 2787 2788 2789 2790 2791 2792 2793 2794 2795 2796 2797 2798 2799 2800 2801 2802 2803 2804 2805 2806 2807 2808 2809 2810 2811 2812 2813 2814 2815 2816 2817 2818 2819 2820 2821 2822 2823 2824 2825 2826 2827 2828 2829 2830 2831 2832 2833 2834 2835 2836 2837 2838 2839 2840 2841 2842 2843 2844 2845 2846 2847 2848 2849 2850 2851 2852 2853 2854 2855 2856 2857 2858 2859 2860 2861 2862 2863 2864 2865 2866 2867 2868 2869 2870 2871 2872 2873 2874 2875 2876 2877 2878 2879 2880 2881 2882 2883 2884 2885 2886 2887 2888 2889 2890 2891 2892 2893 2894 2895 2896 2897 2898 2899 2900 2901 2902 2903 2904 2905 2906 2907 2908 2909 2910 2911 2912 2913 2914 2915 2916 2917 2918 2919 2920 2921 2922 2923 2924 2925 2926 2927 2928 2929 2930 2931 2932 2933 2934 2935 2936 2937 2938 2939 2940 2941 2942 2943 2944 2945 2946 2947 2948 2949 2950 2951 2952 2953 2954 2955 2956 2957 2958 2959 2960 2961 2962 2963 2964 2965 2966 2967 2968 2969 2970 2971 2972 2973 2974 2975 2976 2977 2978 2979 2980 2981 2982 2983 2984 2985 2986 2987 2988 2989 2990 2991 2992 2993 2994 2995 2996 2997 2998 2999 3000 3001 3002 3003 3004 3005 3006 3007 3008 3009 3010 3011 3012 3013 3014 3015 3016 3017 3018 3019 3020 3021 
3022 3023 3024 3025 3026 3027 3028 3029 3030 3031 3032 3033 3034 3035 3036 3037 3038 3039 3040 3041 3042 3043 3044 3045 3046 3047 3048 3049 3050 3051 3052 3053 3054 3055 3056 3057 3058 3059 3060 3061 3062 3063 3064 3065 3066 3067 3068 3069 3070 3071 3072 3073 3074 3075 3076 3077 3078 3079 3080 3081 3082 3083 3084 3085 3086 3087 3088 3089 3090 3091 3092 3093 3094 3095 3096 3097 3098 3099 3100 3101 3102 3103 3104 3105 3106 3107 3108 3109 3110 3111 3112 3113 3114 3115 3116 3117 3118 3119 3120 3121 3122 3123 3124 3125 3126 3127 3128 3129 3130 3131 3132 3133 3134 3135 3136 3137 3138 3139 3140 3141 3142 3143 3144 3145 3146 3147 3148 3149 3150 3151 3152 3153 3154 3155 3156 3157 3158 3159 3160 3161 3162 3163 3164 3165 3166 3167 3168 3169 3170 3171 3172 3173 3174 3175 3176 3177 3178 3179 3180 3181 3182 3183 3184 3185 3186 3187 3188 3189 3190 3191 3192 3193 3194 3195 3196 3197 3198 3199 3200 3201 3202 3203 3204 3205 3206 3207 3208 3209 3210 3211 3212 3213 3214 3215 3216 3217 3218 3219 3220 3221 3222 3223 3224 3225 3226 3227 3228 3229 3230 3231 3232 3233 3234 3235 3236 3237 3238 3239 3240 3241 3242 3243 3244 3245 3246 3247 3248 3249 3250 3251 3252 3253 3254 3255 3256 3257 3258 3259 3260 3261 3262 3263 3264 3265 3266 3267 3268 3269 3270 3271 3272 3273 3274 3275 3276 3277 3278 3279 3280 3281 3282 3283 3284 3285 3286 3287 3288 3289 3290 3291 3292 3293 3294 3295 3296 3297 3298 3299 3300 3301 3302 3303 3304 3305 3306 3307 3308 3309 3310 3311 3312 3313 3314 3315 3316 3317 3318 3319 3320 3321 3322 3323 3324 3325 3326 3327 3328 3329 3330 3331 3332 3333 3334 3335 3336 3337 3338 3339 3340 3341 3342 3343 3344 3345 3346 3347 3348 3349 3350 3351 3352 3353 3354 3355 3356 3357 3358 3359 3360 3361 3362 3363 3364 3365 3366 3367 3368 3369 3370 3371 3372 3373 3374 3375 3376 3377 3378 3379 3380 3381 3382 3383 3384 3385 3386 3387 3388 3389 3390 3391 3392 3393 3394 3395 3396 3397 3398 3399 3400 3401 3402 3403 3404 3405 3406 3407 3408 3409 3410 3411 3412 3413 3414 3415 3416 3417 3418 3419 3420 3421 
3422 3423 3424 3425 3426 3427 3428 3429 3430 3431 3432 3433 3434 3435 3436 3437 3438 3439 3440 3441 3442 3443 3444 3445 3446 3447 3448 3449 3450 3451 3452 3453 3454 3455 3456 3457 3458 3459 3460 3461 3462 3463 3464 3465 3466 3467 3468 3469 3470 3471 3472 3473 3474 3475 3476 3477 3478 3479 3480 3481 3482 3483 3484 3485 3486 3487 3488 3489 3490 3491 3492 3493 3494 3495 3496 3497 3498 3499 3500 3501 3502 3503 3504 3505 3506 3507 3508 3509 3510 3511 3512 3513 3514 3515 3516 3517 3518 3519 3520 3521 3522 3523 3524 3525 3526 3527 3528 3529 3530 3531 3532 3533 3534 3535 3536 3537 3538 3539 3540 3541 3542 3543 3544 3545 3546 3547 3548 3549 3550 3551 3552 3553 3554 3555 3556 3557 3558 3559 3560 3561 3562 3563 3564 3565 3566 3567 3568 3569 3570 3571 3572 3573 3574 3575 3576 3577 3578 3579 3580 3581 3582 3583 3584 3585 3586 3587 3588 3589 3590 3591 3592 3593 3594 3595 3596 3597 3598 3599 3600 3601 3602 3603 3604 3605 3606 3607 3608 3609 3610 3611 3612 3613 3614 3615 3616 3617 3618 3619 3620 3621 3622 3623 3624 3625 3626 3627 3628 3629 3630 3631 3632 3633 3634 3635 3636 3637 3638 3639 3640 3641 3642 3643 3644 3645 3646 3647 3648 3649 3650 3651 3652 3653 3654 3655 3656 3657 3658 3659 3660 3661 3662 3663 3664 3665 3666 3667 3668 3669 3670 3671 3672 3673 3674 3675 3676 3677 3678 3679 3680 3681 3682 3683 3684 3685 3686 3687 3688 3689 3690 3691 3692 3693 3694 3695 3696 3697 3698 3699 3700 3701 3702 3703 3704 3705 3706 3707 3708 3709 3710 3711 3712 3713 3714 3715 3716 3717 3718 3719 3720 3721 3722 3723 3724 3725 3726 3727 3728 3729 3730 3731 3732 3733 3734 3735 3736 3737 3738 3739 3740 3741 3742 3743 3744 3745 3746 3747 3748 3749 3750 3751 3752 3753 3754 3755 3756 3757 3758 3759 3760 3761 3762 3763 3764 3765 3766 3767 3768 3769 3770 3771 3772 3773 3774 3775 3776 3777 3778 3779 3780 3781 3782 3783 3784 3785 3786 3787 3788 3789 3790 3791 3792 3793 3794 3795 3796 3797 3798 3799 3800 3801 3802 3803 3804 3805 3806 3807 3808 3809 3810 3811 3812 3813 3814 3815 3816 3817 3818 3819 3820 3821 
3822 3823 3824 3825 3826 3827 3828 3829 3830 3831 3832 3833 3834 3835 3836 3837 3838 3839 3840 3841 3842 3843 3844 3845 3846 3847 3848 3849 3850 3851 3852 3853 3854 3855 3856 3857 3858 3859 3860 3861 3862 3863 3864 3865 3866 3867 3868 3869 3870 3871 3872 3873 3874 3875 3876 3877 3878 3879 3880 3881 3882 3883 3884 3885 3886 3887 3888 3889 3890 3891 3892 3893 3894 3895 3896 3897 3898 3899 3900 3901 3902 3903 3904 3905 3906 3907 3908 3909 3910 3911 3912 3913 3914 3915 3916 3917 3918 3919 3920 3921 3922 3923 3924 3925 3926 3927 3928 3929 3930 3931 3932 3933 3934 3935 3936 3937 3938 3939 3940 3941 3942 3943 3944 3945 3946 3947 3948 3949 3950 3951 3952 3953 3954 3955 3956 3957 3958 3959 3960 3961 3962 3963 3964 3965 3966 3967 3968 3969 3970 3971 3972 3973 3974 3975 3976 3977 3978 3979 3980 3981 3982 3983 3984 3985 3986 3987 3988 3989 3990 3991 3992 3993 3994 3995 3996 3997 3998 3999 4000 4001 4002 4003 4004 4005 4006 4007 4008 4009 4010 4011 4012 4013 4014 4015 4016 4017 4018 4019 4020 4021 4022 4023 4024 4025 4026 4027 4028 4029 4030 4031 4032 4033 4034 4035 4036 4037 4038 4039 4040 4041 4042 4043 4044 4045 4046 4047 4048 4049 4050 4051 4052 4053 4054 4055 4056 4057 4058 4059 4060 4061 4062 4063 4064 4065 4066 4067 4068 4069 4070 4071 4072 4073 4074 4075 4076 4077 4078 4079 4080 4081 4082 4083 4084 4085 4086 4087 4088 4089 4090 4091 4092 4093 4094 4095 4096 4097 4098 4099 4100 4101 4102 4103 4104 4105 4106 4107 4108 4109 4110 4111 4112 4113 4114 4115 4116 4117 4118 4119 4120 4121 4122 4123 4124 4125 4126 4127 4128 4129 4130 4131 4132 4133 4134 4135 4136 4137 4138 4139 4140 4141 4142 4143 4144 4145 4146 4147 4148 4149 4150 4151 4152 4153 4154 4155 4156 4157 4158 4159 4160 4161 4162 4163 4164 4165 4166 4167 4168 4169 4170 4171 4172 4173 4174 4175 4176 4177 4178 4179 4180 4181 4182 4183 4184 4185 4186 4187 4188 4189 4190 4191 4192 4193 4194 4195 4196 4197 4198 4199 4200 4201 4202 4203 4204 4205 4206 4207 4208 4209 4210 4211 4212 4213 4214 4215 4216 4217 4218 4219 4220 4221 
4222 4223 4224 4225 4226 4227 4228 4229 4230 4231 4232 4233 4234 4235 4236 4237 4238 4239 4240 4241 4242 4243 4244 4245 4246 4247 4248 4249 4250 4251 4252 4253 4254 4255 4256 4257 4258 4259 4260 4261 4262 4263 4264 4265 4266 4267 4268 4269 4270 4271 4272 4273 4274 4275 4276 4277 4278 4279 4280 4281 4282 4283 4284 4285 4286 4287 4288 4289 4290 4291 4292 4293 4294 4295 4296 4297 4298 4299 4300 4301 4302 4303 4304 4305 4306 4307 4308 4309 4310 4311 4312 4313 4314 4315 4316 4317 4318 4319 4320 4321 4322 4323 4324 4325 4326 4327 4328 4329 4330 4331 4332 4333 4334 4335 4336 4337 4338 4339 4340 4341 4342 4343 4344 4345 4346 4347 4348 4349 4350 4351 4352 4353 4354 4355 4356 4357 4358 4359 4360 4361 4362 4363 4364 4365 4366 4367 4368 4369 4370 4371 4372 4373 4374 4375 4376 4377 4378 4379 4380 4381 4382 4383 4384 4385 4386 4387 4388 4389 4390 4391 4392 4393 4394 4395 4396 4397 4398 4399 4400 4401 4402 4403 4404 4405 4406 4407 4408 4409 4410 4411 4412 4413 4414 4415 4416 4417 4418 4419 4420 4421 4422 4423 4424 4425 4426 4427 4428 4429 4430 4431 4432 4433 4434 4435 4436 4437 4438 4439 4440 4441 4442 4443 4444 4445 4446 4447 4448 4449 4450 4451 4452 4453 4454 4455 4456 4457 4458 4459 4460 4461 4462 4463 4464 4465 4466 4467 4468 4469 4470 4471 4472 4473 4474 4475 4476 4477 4478 4479 4480 4481 4482 4483 4484 4485 4486 4487 4488 4489 4490 4491 4492 4493 4494 4495 4496 4497 4498 4499 4500 4501 4502 4503 4504 4505 4506 4507 4508 4509 4510 4511 4512 4513 4514 4515 4516 4517 4518 4519 4520 4521 4522 4523 4524 4525 4526 4527 4528 4529 4530 4531 4532 4533 4534 4535 4536 4537 4538 4539 4540 4541 4542 4543 4544 4545 4546 4547 4548 4549 4550 4551 4552 4553 4554 4555 4556 4557 4558 4559 4560 4561 4562 4563 4564 4565 4566 4567 4568 4569 4570 4571 4572 4573 4574 4575 4576 4577 4578 4579 4580 4581 4582 4583 4584 4585 4586 4587 4588 4589 4590 4591 4592 4593 4594 4595 4596 4597 4598 4599 4600 4601 4602 4603 4604 4605 4606 4607 4608 4609 4610 4611 4612 4613 4614 4615 4616 4617 4618 4619 4620 4621 
4622 4623 4624 4625 4626 4627 4628 4629 4630 4631 4632 4633 4634 4635 4636 4637 4638 4639 4640 4641 4642 4643 4644 4645 4646 4647 4648 4649 4650 4651 4652 4653 4654 4655 4656 4657 4658 4659 4660 4661 4662 4663 4664 4665 4666 4667 4668 4669 4670 4671 4672 4673 4674 4675 4676 4677 4678 4679 4680 4681 4682 4683 4684 4685 4686 4687 4688 4689 4690 4691 4692 4693 4694 4695 4696 4697 4698 4699 4700 4701 4702 4703 4704 4705 4706 4707 4708 4709 4710 4711 4712 4713 4714 4715 4716 4717 4718 4719 4720 4721 4722 4723 4724 4725 4726 4727 4728 4729 4730 4731 4732 4733 4734 4735 4736 4737 4738 4739 4740 4741 4742 4743 4744 4745 4746 4747 4748 4749 4750 4751 4752 4753 4754 4755 4756 4757 4758 4759 4760 4761 4762 4763 4764 4765 4766 4767 4768 4769 4770 4771 4772 4773 4774 4775 4776 4777 4778 4779 4780 4781 4782 4783 4784 4785 4786 4787 4788 4789 4790 4791 4792 4793 4794 4795 4796 4797 4798 4799 4800 4801 4802 4803 4804 4805 4806 4807 4808 4809 4810 4811 4812 4813 4814 4815 4816 4817 4818 4819 4820 4821 4822 4823 4824 4825 4826 4827 4828 4829 4830 4831 4832 4833 4834 4835 4836 4837 4838 4839 4840 4841 4842 4843 4844 4845 4846 4847 4848 4849 4850 4851 4852 4853 4854 4855 4856 4857 4858 4859 4860 4861 4862 4863 4864 4865 4866 4867 4868 4869 4870 4871 4872 4873 4874 4875 4876 4877 4878 4879 4880 4881 4882 4883 4884 4885 4886 4887 4888 4889 4890 4891 4892 4893 4894 4895 4896 4897 4898 4899 4900 4901 4902 4903 4904 4905 4906 4907 4908 4909 4910 4911 4912 4913 4914 4915 4916 4917 4918 4919 4920 4921 4922 4923 4924 4925 4926 4927 4928 4929 4930 4931 4932 4933 4934 4935 4936 4937 4938 4939 4940 4941 4942 4943 4944 4945 4946 4947 4948 4949 4950 4951 4952 4953 4954 4955 4956 4957 4958 4959 4960 4961 4962 4963 4964 4965 4966 4967 4968 4969 4970 4971 4972 4973 4974 4975 4976 4977 4978 4979 4980 4981 4982 4983 4984 4985 4986 4987 4988 4989 4990 4991 4992 4993 4994 4995 4996 4997 4998 4999 5000 5001 5002 5003 5004 5005 5006 5007 5008 5009 5010 5011 5012 5013 5014 5015 5016 5017 5018 5019 5020 5021 
5022 5023 5024 5025 5026 5027 5028 5029 5030 5031 5032 5033 5034 5035 5036 5037 5038 5039 5040 5041 5042 5043 5044 5045 5046 5047 5048 5049 5050 5051 5052 5053 5054 5055 5056 5057 5058 5059 5060 5061 5062 5063 5064 5065 5066 5067 5068 5069 5070 5071 5072 5073 5074 5075 5076 5077 5078 5079 5080 5081 5082 5083 5084 5085 5086 5087 5088 5089 5090 5091 5092 5093 5094 5095 5096 5097 5098 5099 5100 5101 5102 5103 5104 5105 5106 5107 5108 5109 5110 5111 5112 5113 5114 5115 5116 5117 5118 5119 5120 5121 5122 5123 5124 5125 5126 5127 5128 5129 5130 5131 5132 5133 5134 5135 5136 5137 5138 5139 5140 5141 5142 5143 5144 5145 5146 5147 5148 5149 5150 5151 5152 5153 5154 5155 5156 5157 5158 5159 5160 5161 5162 5163 5164 5165 5166 5167 5168 5169 5170 5171 5172 5173 5174 5175 5176 5177 5178 5179 5180 5181 5182 5183 5184 5185 5186 5187 5188 5189 5190 5191 5192 5193 5194 5195 5196 5197 5198 5199 5200 5201 5202 5203 5204 5205 5206 5207 5208 5209 5210 5211 5212 5213 5214 5215 5216 5217 5218 5219 5220 5221 5222 5223 5224 5225 5226 5227 5228 5229 5230 5231 5232 5233 5234 5235 5236 5237 5238 5239 5240 5241 5242 5243 5244 5245 5246 5247 5248 5249 5250 5251 5252 5253 5254 5255 5256 5257 5258 5259 5260 5261 5262 5263 5264 5265 5266 5267 5268 5269 5270 5271 5272 5273 5274 5275 5276 5277 5278 5279 5280 5281 5282 5283 5284 5285 5286 5287 5288 5289 5290 5291 5292 5293 5294 5295 5296 5297 5298 5299 5300 5301 5302 5303 5304 5305 5306 5307 5308 5309 5310 5311 5312 5313 5314 5315 5316 5317 5318 5319 5320 5321 5322 5323 5324 5325 5326 5327 5328 5329 5330 5331 5332 5333 5334 5335 5336 5337 5338 5339 5340 5341 5342 5343 5344 5345 5346 5347 5348 5349 5350 5351 5352 5353 5354 5355 5356 5357 5358 5359 5360 5361 5362 5363 5364 5365 5366 5367 5368 5369 5370 5371 5372 5373 5374 5375 5376 5377 5378 5379 5380 5381 5382 5383 5384 5385 5386 5387 5388 5389 5390 5391 5392 5393 5394 5395 5396 5397 5398 5399 5400 5401 5402 5403 5404 5405 5406 5407 5408 5409 5410 5411 5412 5413 5414 5415 5416 5417 5418 5419 5420 5421 
5422 5423 5424 5425 5426 5427 5428 5429 5430 5431 5432 5433 5434 5435 5436 5437 5438 5439 5440 5441 5442 5443 5444 5445 5446 5447 5448 5449 5450 5451 5452 5453 5454 5455 5456 5457 5458 5459 5460 5461 5462 5463 5464 5465 5466 5467 5468 5469 5470 5471 5472 5473 5474 5475 5476 5477 5478 5479 5480 5481 5482 5483 5484 5485 5486 5487 5488 5489 5490 5491 5492 5493 5494 5495 5496 5497 5498 5499 5500 5501 5502 5503 5504 5505 5506 5507 5508 5509 5510 5511 5512 5513 5514 5515 5516 5517 5518 5519 5520 5521 5522 5523 5524 5525 5526 5527 5528 5529 5530 5531 5532 5533 5534 5535 5536 5537 5538 5539 5540 5541 5542 5543 5544 5545 5546 5547 5548 5549 5550 5551 5552 5553 5554 5555 5556 5557 5558 5559 5560 5561 5562 5563 5564 5565 5566 5567 5568 5569 5570 5571 5572 5573 5574 5575 5576 5577 5578 5579 5580 5581 5582 5583 5584 5585 5586 5587 5588 5589 5590 5591 5592
|
#include <ATen/PythonTorchFunctionTLS.h>
#include <c10/core/SafePyObject.h>
#include <c10/core/impl/PyInterpreter.h>
#define PY_SSIZE_T_CLEAN
#include <ATen/EmptyTensor.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <c10/util/flat_hash_map.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/pythoncapi_compat.h>
#include <torch/extension.h>
#include <torch/csrc/dynamo/debug_macros.h>
#ifdef USE_CUDA
#include <ATen/cuda/EmptyTensor.h>
#endif
#ifdef USE_XPU
#include <ATen/xpu/EmptyTensor.h>
#endif
#include <chrono>
#include <sstream>
#include <tuple>
#include <utility>
// Certain CPython data structures are defined in `.c` files in earlier Python
// versions, e.g., for TupleIteratorGetItemAccessor, we need a fast way to
// retrieve the underlying tuple and access the item. Before Python 3.12
// version, the data structure is in tupleobject.c file -
// https://github.com/python/cpython/blob/9afc6d102d16080535325f645849cd84eb04d57d/Objects/tupleobject.c#L1058-L1062
//
// To handle the older python versions, we manually copy the struct here and
// manually cast it to this new struct. For newer versions, the struct is
// included in the header file.
#if IS_PYTHON_3_12_PLUS
#define Py_BUILD_CORE
#include <internal/pycore_range.h> // _PyRangeIterObject
#include <internal/pycore_tuple.h> // _PyTupleIterObject
#undef Py_BUILD_CORE
#else
// Manually create _PyTupleIterObject struct.
// This is a local copy of CPython's private tuple-iterator layout, compiled
// only when IS_PYTHON_3_12_PLUS is false (for 3.12+ the definition comes from
// internal/pycore_tuple.h instead). Tuple iterator objects are cast to this
// struct, so the field order and types must stay identical to CPython's own
// definition.
typedef struct {
PyObject_HEAD
Py_ssize_t it_index; // presumably the index of the next item to yield -- confirm against CPython
PyTupleObject* it_seq; /* Set to NULL when iterator is exhausted */
} _PyTupleIterObject;
// Copied from CPython, and given a unified name for different Python versions.
// https://github.com/python/cpython/blob/7f71003b222ad398713514c2b55d34dc05dba6bc/Objects/rangeobject.c#L765-L771
// Compiled only when IS_PYTHON_3_12_PLUS is false; the field order and types
// must stay identical to the CPython definition this was copied from.
typedef struct {
PyObject_HEAD
// NOTE for Python 3.12+, `index` is removed, and `start` is updated in place
// instead, upon each `next(...)` call. See
// https://github.com/python/cpython/pull/27986
long index;
long start;
long step;
long len;
} _PyRangeIterObject;
#endif // IS_PYTHON_3_12_PLUS
namespace torch::dynamo {
// Macro to skip addition of duplicate guards like EQUALS_MATCH.
//
// Must be expanded where a variable named `self` is in scope and provides
// is_leaf_guard_present(name) and insert_leaf_guard(name). If `self` already
// reports a leaf guard under `name`, the enclosing function returns
// immediately with a bare `return;`. Otherwise `name` is recorded via
// insert_leaf_guard and execution continues with the code after the macro.
//
// NOTE: the expansion is an `if` block followed by a separate statement, not a
// single statement, so it should not be used as the body of an unbraced
// if/else.
#define SKIP_IF_GUARD_ALREADY_PRESENT(name) \
if (self.is_leaf_guard_present(name)) { \
return; \
} \
self.insert_leaf_guard(name);
// Builds a tensor guard by snapshotting the metadata of an example tensor `v`.
//
// Recorded properties:
//  - pytype:         the Python type object passed in as `pt`
//  - dispatch_key_:  raw representation of v.key_set() after state.apply()
//  - dtype_:         v's scalar type
//  - device_index_:  index of v's device
//  - requires_grad_: v.requires_grad()
//  - sizes_/strides_: the per-dimension optional expectations supplied by the
//                     caller (moved in); an empty optional means that entry is
//                     not compared by check()
//  - dim_:           number of entries in sizes_
//
// NOTE(review): dim_ is computed from the member sizes_ (not the moved-from
// parameter), which presumably relies on sizes_ being declared before dim_ in
// the class so it is initialized first -- confirm against the header.
TensorCheck::TensorCheck(
const LocalState& state,
PyTypeObject* pt,
const at::Tensor& v,
std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
: pytype(pt),
dispatch_key_(state.apply(v.key_set()).raw_repr()),
dtype_(v.dtype().toScalarType()),
device_index_(v.device().index()),
requires_grad_(v.requires_grad()),
sizes_(std::move(dynamic_dims_sizes)),
strides_(std::move(dynamic_dims_strides)),
dim_(static_cast<int64_t>(sizes_.size())) {
// TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
// we just treat this as optional?
}
// Builds a tensor guard from explicitly supplied metadata instead of an
// example tensor.
//
// The dispatch key set is passed through state.apply() and stored as its raw
// representation; dtype, device index and requires_grad are stored as given.
// The per-dimension optional size/stride expectations are moved into sizes_
// and strides_, and dim_ is set to the number of entries in sizes_.
//
// NOTE(review): dim_ is computed from the member sizes_ (not the moved-from
// parameter), which presumably relies on sizes_ being declared before dim_ in
// the class so it is initialized first -- confirm against the header.
TensorCheck::TensorCheck(
const LocalState& state,
PyTypeObject* pt,
c10::DispatchKeySet dispatch_key_set,
at::ScalarType dtype,
at::DeviceIndex device_index,
bool requires_grad,
std::vector<std::optional<c10::SymInt>> dynamic_dims_sizes,
std::vector<std::optional<c10::SymInt>> dynamic_dims_strides)
: pytype(pt),
dispatch_key_(state.apply(dispatch_key_set).raw_repr()),
dtype_(dtype),
device_index_(device_index),
requires_grad_(requires_grad),
sizes_(std::move(dynamic_dims_sizes)),
strides_(std::move(dynamic_dims_strides)),
dim_(static_cast<int64_t>(sizes_.size())) {}
// See note in guards.py [Note - On Export Tensor Guards]
// Logic parallel to here must be maintained in python
// Checks a live tensor against this guard by extracting its metadata and
// delegating to the metadata-based overload of check().
//
// Sparse compressed layouts (CSR/CSC/BSR/BSC) do not support stride
// information, so for those a placeholder stride list holding -1 for every
// dimension is passed instead of calling v.sym_strides().
//
// Returns true when the tensor matches the recorded guard, false otherwise.
bool TensorCheck::check(const LocalState& state, const at::Tensor& v) {
  const auto layout = v.layout();
  const bool does_not_support_stride = layout == c10::kSparseCsr ||
      layout == c10::kSparseCsc || layout == c10::kSparseBsc ||
      layout == c10::kSparseBsr;

  // SymIntArrayRef is a non-owning view. The placeholder strides therefore
  // live in a named local that outlives the delegated call; viewing a
  // temporary vector would leave the view dangling once the temporary is
  // destroyed at the end of its statement.
  std::vector<c10::SymInt> placeholder_strides;
  c10::SymIntArrayRef sym_strides;
  if (does_not_support_stride) {
    placeholder_strides.assign(
        static_cast<size_t>(v.ndimension()), c10::SymInt(-1));
    sym_strides = c10::SymIntArrayRef(placeholder_strides);
  } else {
    sym_strides = v.sym_strides();
  }

  return check(
      state,
      v.key_set(),
      v.dtype().toScalarType(),
      v.device(),
      v.sym_sizes(),
      sym_strides,
      v.requires_grad());
}
// Compares explicitly supplied tensor metadata against the values recorded in
// this guard.
//
// The scalar properties (dispatch key set after state.apply(), dtype, device
// index, requires_grad) are compared first, then the rank, and finally every
// dimension whose recorded size or stride expectation is present. Dimensions
// whose recorded optional is empty are not compared.
//
// Returns true only if every comparison succeeds.
bool TensorCheck::check(
    const LocalState& state,
    const c10::DispatchKeySet& dispatch_key_set,
    const at::ScalarType& dtype,
    const c10::Device& device,
    const c10::SymIntArrayRef& sym_sizes,
    const c10::SymIntArrayRef& sym_strides,
    const bool& requires_grad) {
  // Scalar metadata, evaluated in the same order as a short-circuit chain.
  if (dispatch_key_ != state.apply(dispatch_key_set).raw_repr()) {
    return false;
  }
  if (dtype_ != dtype) {
    return false;
  }
  if (device_index_ != device.index()) {
    return false;
  }
  if (requires_grad_ != requires_grad) {
    return false;
  }

  // The rank has to agree before the per-dimension tables can be indexed.
  const size_t rank = sym_sizes.size();
  if (rank != static_cast<size_t>(dim_)) {
    return false;
  }

  for (size_t d = 0; d < rank; ++d) {
    auto expected_size = sizes_[d];
    auto expected_stride = strides_[d];
    if (expected_size.has_value() && expected_size.value() != sym_sizes[d]) {
      return false;
    }
    if (expected_stride.has_value() &&
        expected_stride.value() != sym_strides[d]) {
      return false;
    }
  }
  return true;
}
// Verbose counterpart of check(): returns an empty string when `v` satisfies
// the guard, otherwise a human-readable description of the first mismatch,
// prefixed with the tensor's name.
std::string TensorCheck::check_verbose(
    const LocalState& state,
    const at::Tensor& v,
    const std::string& tensor_name) {
  std::stringstream fail_reason;
  fail_reason << "tensor '" << tensor_name << "' ";

  if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
    fail_reason << "dispatch key set mismatch. expected "
                << c10::DispatchKeySet(c10::DispatchKeySet::RAW, dispatch_key_)
                << ", actual " << state.apply(v.key_set());
    return fail_reason.str();
  }
  if (dtype_ != v.dtype().toScalarType()) {
    fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
                << v.dtype().toScalarType();
    return fail_reason.str();
  }
  if (device_index_ != v.device().index()) {
    // Stream the indices as plain ints: if at::DeviceIndex is a
    // character-sized integer type, operator<< would otherwise emit a raw
    // character instead of the number.
    fail_reason << "Tensor device index mismatch. Expected device index to be "
                << static_cast<int>(device_index_) << ", actual "
                << static_cast<int>(v.device().index());
    return fail_reason.str();
  }
  if (requires_grad_ != v.requires_grad()) {
    fail_reason << "requires_grad mismatch. expected requires_grad="
                << requires_grad_;
    return fail_reason.str();
  }

  auto ndim = v.ndimension();
  if (ndim != dim_) {
    fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
                << ndim;
    return fail_reason.str();
  }

  // Entries without a value are dynamic and are not compared.
  const auto& sizes = v.sym_sizes();
  for (auto i : c10::irange(ndim)) {
    const auto& known_size = sizes_[i];
    if (known_size.has_value() && (known_size.value() != sizes[i])) {
      fail_reason << "size mismatch at index " << i << ". expected "
                  << known_size.value() << ", actual " << sizes[i];
      return fail_reason.str();
    }
  }

  // Strides are only compared for tensors that are neither sparse nor
  // sparse-compressed.
  const bool supports_stride =
      !v.is_sparse() && !at::sparse_csr::is_sparse_compressed(v);
  if (supports_stride) {
    const auto& strides = v.sym_strides();
    for (auto i : c10::irange(ndim)) {
      const auto& known_stride = strides_[i];
      if (known_stride.has_value() && known_stride.value() != strides[i]) {
        fail_reason << "stride mismatch at index " << i << ". expected "
                    << known_stride.value() << ", actual " << strides[i];
        return fail_reason.str();
      }
    }
  }
  return "";
}
namespace {
// The per-tensor checks owned by a TensorGuards instance.
using ChecksList = std::vector<TensorCheck>;

// Python object layout for TensorGuards: the standard object header followed
// by a heap-allocated list of checks (allocated in TensorGuards_new and
// released in TensorGuards_dealloc).
struct TensorGuards {
  PyObject_HEAD
  ChecksList* checks;
};
// tp_dealloc slot: releases the owned ChecksList and then the object itself.
static void TensorGuards_dealloc(TensorGuards* self) {
  // `delete` on a null pointer is a no-op, so no explicit guard is needed.
  delete self->checks;
  self->checks = nullptr;
  Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
// tp_new slot: allocates the Python object and an empty ChecksList for it.
// Returns nullptr when the allocation through tp_alloc fails.
static PyObject* TensorGuards_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  auto* guards = reinterpret_cast<TensorGuards*>(type->tp_alloc(type, 0));
  if (guards == nullptr) {
    return nullptr;
  }
  guards->checks = new ChecksList();
  return reinterpret_cast<PyObject*>(guards);
}
// Wraps every element of `intArray` in an engaged std::optional, preserving
// order.
static std::vector<std::optional<c10::SymInt>> wrapIntegersInOptional(
    const c10::SymIntArrayRef& intArray) {
  std::vector<std::optional<c10::SymInt>> wrapped;
  wrapped.reserve(intArray.size());
  for (const c10::SymInt& value : intArray) {
    wrapped.emplace_back(value);
  }
  return wrapped;
}
// Converts a Python list of sizes/strides into a vector of optional SymInts:
//   None      -> std::nullopt
//   SymInt    -> the SymInt itself
//   otherwise -> parsed as an integer
// Sets a Python TypeError and throws (TORCH_CHECK) when an item is not a
// valid integer.
static std::vector<std::optional<c10::SymInt>> pyListToVecOptInt(
    PyObject* pyList) {
  std::vector<std::optional<c10::SymInt>> converted;
  const Py_ssize_t count = PyList_Size(pyList);
  for (Py_ssize_t idx = 0; idx < count; ++idx) {
    PyObject* entry = PyList_GetItem(pyList, idx);
    if (entry == Py_None) {
      converted.emplace_back(std::nullopt);
      continue;
    }

    py::handle entry_handle(entry);
    if (torch::is_symint(entry_handle)) {
      converted.emplace_back(py::cast<c10::SymInt>(entry_handle));
      continue;
    }

    const int64_t as_int = PyLong_AsLongLong(entry);
    // -1 is only an error indicator when a Python exception is pending.
    if (as_int == -1 && PyErr_Occurred()) {
      PyErr_SetString(
          PyExc_TypeError, "Size or stride list item is not a valid integer.");
      TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
    }
    converted.emplace_back(c10::SymInt(as_int));
  }
  return converted;
}
// Converts the `dynamic_dims_*` keyword argument (None, or a list with one
// inner list per tensor) into per-tensor vectors of optional SymInts.
// Returns an empty vector when the argument is None.
static std::vector<std::vector<std::optional<c10::SymInt>>> get_dynamic_dims(
    PyObject* dynamic_dims_py) {
  std::vector<std::vector<std::optional<c10::SymInt>>> per_tensor_dynamic_dims;
  if (dynamic_dims_py == Py_None) {
    return per_tensor_dynamic_dims;
  }
  const Py_ssize_t count = PyList_Size(dynamic_dims_py);
  for (Py_ssize_t idx = 0; idx < count; ++idx) {
    per_tensor_dynamic_dims.push_back(
        pyListToVecOptInt(PyList_GetItem(dynamic_dims_py, idx)));
  }
  return per_tensor_dynamic_dims;
}
// tp_init slot: records one TensorCheck per positional tensor argument.
//
// Requires the keyword arguments `dynamic_dims_sizes` and
// `dynamic_dims_strides`, each either None or a list with one inner list per
// tensor (List[List[Union[int, None]]]).
// Returns 0 on success, or -1 with a Python exception set.
static int TensorGuards_init(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwds) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return -1;
  }
  // `kwds` is NULL when the caller passes no keyword arguments, and
  // PyDict_GetItemString must not be handed a NULL dict.
  PyObject* dynamic_dims_sizes_py = kwds != nullptr
      ? PyDict_GetItemString(kwds, "dynamic_dims_sizes")
      : nullptr;
  if (dynamic_dims_sizes_py == nullptr) {
    PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
    return -1;
  }
  PyObject* dynamic_dims_strides_py = kwds != nullptr
      ? PyDict_GetItemString(kwds, "dynamic_dims_strides")
      : nullptr;
  if (dynamic_dims_strides_py == nullptr) {
    PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
    return -1;
  }
  if (dynamic_dims_sizes_py != Py_None &&
      !PyList_Check(dynamic_dims_sizes_py)) {
    PyErr_SetString(
        PyExc_TypeError, "dynamic_dims_sizes must be a list or None");
    return -1;
  }
  if (dynamic_dims_strides_py != Py_None &&
      !PyList_Check(dynamic_dims_strides_py)) {
    PyErr_SetString(
        PyExc_TypeError, "dynamic_dims_strides must be a list or None");
    return -1;
  }

  // The helpers below can throw C++ exceptions (e.g. for a non-integer list
  // item); those must not propagate through this C entry point.
  try {
    // dynamic_dims_strides/sizes_py is None when dynamic_shapes=False - this
    // is an optimization to avoid invoking .size()/.stride() in python
    // needlessly
    std::vector<std::vector<std::optional<c10::SymInt>>>
        per_tensor_dynamic_dims_sizes = get_dynamic_dims(dynamic_dims_sizes_py);
    std::vector<std::vector<std::optional<c10::SymInt>>>
        per_tensor_dynamic_dims_strides =
            get_dynamic_dims(dynamic_dims_strides_py);
    // The helpers report some failures (e.g. an inner entry that is not a
    // list) only through the Python error indicator.
    if (PyErr_Occurred()) {
      return -1;
    }

    auto& checks = *self->checks;
    auto len = PyTuple_GET_SIZE(args);
    // A non-empty list is indexed once per tensor below, so it needs at
    // least `len` entries.
    const auto num_tensors = static_cast<size_t>(len);
    const bool sizes_too_short = !per_tensor_dynamic_dims_sizes.empty() &&
        per_tensor_dynamic_dims_sizes.size() < num_tensors;
    const bool strides_too_short = !per_tensor_dynamic_dims_strides.empty() &&
        per_tensor_dynamic_dims_strides.size() < num_tensors;
    if (sizes_too_short || strides_too_short) {
      PyErr_SetString(
          PyExc_TypeError,
          "dynamic_dims_sizes/dynamic_dims_strides must have one entry per tensor");
      return -1;
    }

    checks.reserve(len);
    LocalState state;
    for (auto i : c10::irange(len)) {
      PyObject* item = PyTuple_GET_ITEM(args, i);
      if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
        PyErr_SetString(PyExc_TypeError, "expected Tensor()");
        return -1;
      }
      auto tensor = THPVariable_Unpack(item);
      // Without dynamic dims, guard on the tensor's current sizes/strides.
      std::vector<std::optional<c10::SymInt>> tensor_dims_size =
          per_tensor_dynamic_dims_sizes.empty()
          ? wrapIntegersInOptional(tensor.sym_sizes())
          : per_tensor_dynamic_dims_sizes[i];
      std::vector<std::optional<c10::SymInt>> tensor_dims_stride =
          per_tensor_dynamic_dims_strides.empty()
          ? wrapIntegersInOptional(tensor.sym_strides())
          : per_tensor_dynamic_dims_strides[i];
      checks.emplace_back(
          state,
          Py_TYPE(item),
          std::move(tensor),
          std::move(tensor_dims_size),
          std::move(tensor_dims_stride));
    }
    return 0;
  } catch (const std::exception& e) {
    // Keep a more specific Python error if the thrower already set one.
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_RuntimeError, e.what());
    }
    return -1;
  }
}
// Runs every recorded TensorCheck against the positional arguments.
// Returns True when all checks pass, False on the first failure, and nullptr
// (with TypeError set) when the arguments have the wrong shape.
PyObject* TensorGuards_check(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwargs) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return nullptr;
  }
  auto& checks = *self->checks;
  const Py_ssize_t num_args = PyTuple_GET_SIZE(args);

  // kwargs is just ignored here
  if (static_cast<Py_ssize_t>(checks.size()) != num_args) {
    PyErr_SetString(PyExc_TypeError, "wrong length");
    return nullptr;
  }

  LocalState state;
  // Note - all the tensors that make it to guards must be unique. Dynamo
  // builder handles guarding for positive aliases (X is Y). However, we do not
  // create guards for negative alias (X is not Y) as that is an N^2
  // relationship. Instead, we rely on the uniqueness upstream to verify, at
  // check_fn time (this function).
  ska::flat_hash_map<PyObject*, std::nullptr_t> unique_tensors;
  for (Py_ssize_t idx = 0; idx < num_args; ++idx) {
    PyObject* item = PyTuple_GET_ITEM(args, idx);
    // Evaluated left to right: exact Python type, then uniqueness, then the
    // tensor metadata check.
    const bool passes = Py_TYPE(item) == checks[idx].pytype &&
        unique_tensors.insert({item, nullptr}).second &&
        checks[idx].check(state, THPVariable_Unpack(item));
    if (!passes) {
      Py_RETURN_FALSE;
    }
  }
  Py_RETURN_TRUE;
}
// Verbose variant of TensorGuards_check.
//
// Requires the keyword argument `tensor_check_names`: a list of strings, one
// per guarded tensor. Returns True when every check passes, otherwise a
// Python string describing the first failure. Returns nullptr (with an
// exception set) when the arguments are malformed.
PyObject* TensorGuards_check_verbose(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwargs) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return nullptr;
  }
  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);
  if (static_cast<decltype(len)>(checks.size()) != len) {
    PyErr_SetString(PyExc_TypeError, "wrong length");
    return nullptr;
  }

  // `kwargs` is NULL when the caller passes no keyword arguments, and
  // PyDict_GetItemString must not be handed a NULL dict.
  PyObject* tensor_check_names_py = kwargs != nullptr
      ? PyDict_GetItemString(kwargs, "tensor_check_names")
      : nullptr;
  if (tensor_check_names_py == nullptr) {
    PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
    return nullptr;
  }
  if (!PyList_Check(tensor_check_names_py)) {
    PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list");
    return nullptr;
  }
  auto names_size = PyList_Size(tensor_check_names_py);
  if (names_size != static_cast<decltype(names_size)>(checks.size())) {
    PyErr_SetString(
        PyExc_TypeError,
        "tensor_check_names should be the same size as # tensors");
    return nullptr;
  }

  std::vector<std::string> tensor_check_names;
  tensor_check_names.reserve(names_size);
  for (auto i : c10::irange(names_size)) {
    PyObject* value = PyList_GetItem(tensor_check_names_py, i);
    if (!PyUnicode_Check(value)) {
      PyErr_SetString(
          PyExc_TypeError, "tensor_check_names must only contain strings");
      return nullptr;
    }
    // PyUnicode_AsUTF8 returns NULL (with an exception set) on failure;
    // constructing a std::string from NULL is undefined behavior.
    const char* name_utf8 = PyUnicode_AsUTF8(value);
    if (name_utf8 == nullptr) {
      return nullptr;
    }
    tensor_check_names.emplace_back(name_utf8);
  }

  LocalState state;
  ska::flat_hash_map<PyObject*, std::nullptr_t> unique_tensors;
  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (Py_TYPE(item) != checks[i].pytype) {
      std::stringstream fail_reason;
      fail_reason << "expected type of '" << tensor_check_names[i]
                  << "' to be a tensor type, ";
      // PyObject_Type and PyObject_Str both return new references that have
      // to be released once the text has been copied into the stream.
      PyObject* type_obj = PyObject_Type(item);
      PyObject* type_str =
          type_obj != nullptr ? PyObject_Str(type_obj) : nullptr;
      const char* type_utf8 =
          type_str != nullptr ? PyUnicode_AsUTF8(type_str) : nullptr;
      if (type_utf8 == nullptr) {
        // Do not return a value while an exception is still pending.
        PyErr_Clear();
        fail_reason << "but found a different type";
      } else {
        // NOTE(review): the leading quote in this literal looks like a typo;
        // left unchanged in case callers match on the message text.
        fail_reason << "' but found " << type_utf8;
      }
      Py_XDECREF(type_str);
      Py_XDECREF(type_obj);
      return Py_BuildValue("s", fail_reason.str().c_str());
    }

    // Every guarded tensor must be a distinct object (no aliasing).
    auto insertion = unique_tensors.insert({item, nullptr});
    if (!insertion.second) {
      std::stringstream fail_reason;
      fail_reason << "Duplicate tensor found where not expected! ";
      fail_reason << tensor_check_names[i]
                  << "should not alias to anything, but is aliased";
      return Py_BuildValue("s", fail_reason.str().c_str());
    }

    std::string fail_reason = checks[i].check_verbose(
        state, THPVariable_Unpack(item), tensor_check_names[i]);
    if (fail_reason.length() > 0) {
      return Py_BuildValue("s", fail_reason.c_str());
    }
  }
  Py_RETURN_TRUE;
}
// Method table for the TensorGuards Python type. Both entries accept
// positional and keyword arguments (METH_VARARGS | METH_KEYWORDS):
//   check         -> TensorGuards_check
//   check_verbose -> TensorGuards_check_verbose
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef TensorGuards_methods[] = {
{"check",
(PyCFunction)(void*)TensorGuards_check,
METH_VARARGS | METH_KEYWORDS,
""},
{"check_verbose",
(PyCFunction)(void*)TensorGuards_check_verbose,
METH_VARARGS | METH_KEYWORDS,
"verbose fail reasons for failed checks"},
{nullptr} /* Sentinel */
};
// Python type object for TensorGuards. Only the object header is initialized
// here; the remaining slots are presumably filled in elsewhere in this file
// before the type is readied (not visible in this chunk) — TODO confirm.
static PyTypeObject TensorGuardsType = { PyVarObject_HEAD_INIT(nullptr, 0)
};
// Snapshot of process-global state (grad mode, torch function mode,
// determinism flags, cuBLAS precision flags, thread count, default dtype).
// init() captures the current values, check() reports whether they are all
// unchanged, and reason() names the ones that differ.
//
// TODO (janimesh) - Remove the PyObject_HEAD part when C++ guard manager is
// merged.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct GlobalStateGuard {
  PyObject_HEAD

  // Captures the current global state into this object.
  inline void init() {
    auto& ctx = at::globalContext();
    _grad_mode = at::GradMode::is_enabled();
    // The below two flags disambiguate
    // if torch function disabled state is
    // 1) enabled, 2) all disabled, 3) subclasses disabled
    // we guard on the stack separately
    _torch_function = torch::torch_function_enabled();
    _torch_function_all_disabled = at::impl::torch_function_all_disabled();
    _deterministic_algorithms = ctx.deterministicAlgorithms();
    _deterministic_algorithms_warn_only = ctx.deterministicAlgorithmsWarnOnly();
    _allow_tf32 = ctx.allowTF32CuBLAS();
    _allow_fp16_reduce = ctx.allowFP16ReductionCuBLAS();
    _allow_bf16_reduce = ctx.allowBF16ReductionCuBLAS();
    _num_threads = at::get_num_threads();
    _default_dtype = at::get_default_dtype();
  }

  // Returns true when every captured value still equals the current global
  // state.
  inline bool check() const {
    auto& ctx = at::globalContext();
    return _grad_mode == at::GradMode::is_enabled() &&
        _torch_function == torch::torch_function_enabled() &&
        _torch_function_all_disabled ==
        at::impl::torch_function_all_disabled() &&
        _deterministic_algorithms == ctx.deterministicAlgorithms() &&
        _deterministic_algorithms_warn_only ==
        ctx.deterministicAlgorithmsWarnOnly() &&
        _allow_tf32 == ctx.allowTF32CuBLAS() &&
        _allow_fp16_reduce == ctx.allowFP16ReductionCuBLAS() &&
        _allow_bf16_reduce == ctx.allowBF16ReductionCuBLAS() &&
        _num_threads == at::get_num_threads() &&
        _default_dtype == at::get_default_dtype();
  }

  // Returns a space-separated list naming every captured value that differs
  // from the current global state (empty when check() would pass). Covers
  // exactly the fields that check() compares.
  inline std::string reason() const {
    std::ostringstream os;
    auto& ctx = at::globalContext();
    if (_grad_mode != at::GradMode::is_enabled())
      os << "grad_mode ";
    if (_torch_function != torch::torch_function_enabled())
      os << "torch_function ";
    // check() also compares this flag; without reporting it here a failing
    // guard could produce an empty reason.
    if (_torch_function_all_disabled !=
        at::impl::torch_function_all_disabled())
      os << "torch_function_all_disabled ";
    if (_deterministic_algorithms != ctx.deterministicAlgorithms())
      os << "deterministic_algorithms ";
    if (_deterministic_algorithms_warn_only !=
        ctx.deterministicAlgorithmsWarnOnly())
      os << "deterministic_algorithms_warn_only ";
    if (_allow_tf32 != ctx.allowTF32CuBLAS())
      os << "allow_tf32 ";
    if (_allow_fp16_reduce != ctx.allowFP16ReductionCuBLAS())
      os << "allow_fp16_reduce ";
    if (_allow_bf16_reduce != ctx.allowBF16ReductionCuBLAS())
      os << "allow_bf16_reduce ";
    if (_num_threads != at::get_num_threads())
      os << "num_threads ";
    if (_default_dtype != at::get_default_dtype())
      os << "default_dtype ";
    return os.str();
  }

  // Values captured by init().
  bool _grad_mode;
  bool _torch_function;
  bool _torch_function_all_disabled;
  bool _deterministic_algorithms;
  bool _deterministic_algorithms_warn_only;
  bool _allow_tf32;
  bool _allow_fp16_reduce;
  bool _allow_bf16_reduce;
  int _num_threads;
  caffe2::TypeMeta _default_dtype;
  // TODO(jansel): we should guard on more state as inductor starts using it
};
// tp_init-style entry point for GlobalStateGuard: captures the current global
// state via GlobalStateGuard::init(). `args` and `kwargs` are not read.
// Always returns 0.
int GlobalStateGuard_init(
GlobalStateGuard* self,
PyObject* args,
PyObject* kwargs) {
self->init();
return 0;
}
// Python-facing wrapper around GlobalStateGuard::check(): returns True when
// the captured state still matches, False otherwise.
PyObject* GlobalStateGuard_check(
    GlobalStateGuard* self,
    PyObject* args,
    PyObject* kwargs) {
  if (!self->check()) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
// Python-facing wrapper around GlobalStateGuard::reason(): returns the
// mismatch description as a Python string.
PyObject* GlobalStateGuard_reason(
    GlobalStateGuard* self,
    PyObject* args,
    PyObject* kwargs) {
  const std::string mismatches = self->reason();
  return PyUnicode_FromString(mismatches.c_str());
}
// Method table for the GlobalStateGuard Python type:
//   check  -> GlobalStateGuard_check
//   reason -> GlobalStateGuard_reason
// NOTE(review): both callbacks are declared with a third `kwargs` parameter
// but are registered as METH_NOARGS; confirm the extra parameter is
// intentionally unused.
// NOLINTNEXTLINE(*array*)
static PyMethodDef GlobalStateGuard_methods[] = {
{"check",
(PyCFunction)(void*)GlobalStateGuard_check,
METH_NOARGS,
"Return true if global state was the same as at creation time"},
{"reason",
(PyCFunction)(void*)GlobalStateGuard_reason,
METH_NOARGS,
"Return string reason for guard check failing"},
{nullptr}};
// Python type object for GlobalStateGuard. Only the object header is
// initialized here; the remaining slots are presumably filled in elsewhere in
// this file before the type is readied (not visible in this chunk) — TODO
// confirm.
static PyTypeObject GlobalStateGuardType = { PyVarObject_HEAD_INIT(nullptr, 0)
};
// faster `lambda obj, expected: id(type(obj)) == expected`
//
// Expects two arguments: an object and the expected type address as an
// unsigned integer. Returns True when the object's type is at that address.
static PyObject* check_type_id(PyObject* dummy, PyObject* args) {
  PyObject* target = nullptr;
  unsigned long long expected_id = 0;
  if (!PyArg_ParseTuple(args, "OK", &target, &expected_id)) {
    return nullptr;
  }
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  const bool same_type = Py_TYPE(target) == (void*)expected_id;
  if (same_type) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
// faster `lambda obj, expected: id(obj) == expected`
//
// Expects two arguments: an object and the expected object address as an
// unsigned integer. Returns True when the object lives at that address.
static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
  PyObject* target = nullptr;
  unsigned long long expected_id = 0;
  if (!PyArg_ParseTuple(args, "OK", &target, &expected_id)) {
    return nullptr;
  }
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  const bool same_object = target == (void*)expected_id;
  if (same_object) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
#if IS_PYTHON_3_12_PLUS
// Maps each tracked dict to the version number most recently assigned to it
// (by dict_version_watch_callback or get_dict_version_unchecked).
static std::unordered_map<PyObject*, uint64_t> dict_version_map;
// Watcher id handed to PyDict_Watch. Presumably registered during module
// initialization (not visible in this chunk) — TODO confirm.
static int dict_version_watcher_id;
// Next version number to hand out; post-incremented on every assignment.
static uint64_t global_dict_version_id = 1;
// Dict watcher callback: keeps dict_version_map in sync with dict events.
// Deallocation drops the entry, a clone event is ignored, and every other
// event assigns the dict a fresh version number. Always returns 0.
static int dict_version_watch_callback(
    PyDict_WatchEvent event,
    PyObject* dict,
    PyObject* key,
    PyObject* new_value) noexcept {
  switch (event) {
    case PyDict_EVENT_DEALLOCATED:
      dict_version_map.erase(dict);
      break;
    case PyDict_EVENT_CLONED:
      break;
    default:
      dict_version_map[dict] = global_dict_version_id++;
      break;
  }
  return 0;
}
#endif
// Returns the version of `dict` without verifying that it is a dict.
// On Python 3.12+ the dict is registered with the version watcher and the
// version comes from dict_version_map (assigned on first use); on older
// versions it is read from the dict's `ma_version_tag`.
// Throws std::runtime_error when the watcher cannot be attached.
static uint64_t get_dict_version_unchecked(PyObject* dict) {
#if IS_PYTHON_3_12_PLUS
  if (PyDict_Watch(dict_version_watcher_id, dict)) {
    throw std::runtime_error("failed to add version watcher to dict!");
  }
  auto entry = dict_version_map.find(dict);
  if (entry == dict_version_map.end()) {
    // First time this dict is seen: hand out a fresh version number.
    entry = dict_version_map.emplace(dict, global_dict_version_id++).first;
  }
  return entry->second;
#else
  return reinterpret_cast<PyDictObject*>(dict)->ma_version_tag;
#endif
}
// Retrieves the version of a dictionary.
//
// Expects a single dict argument and returns its version as a Python int.
// Returns nullptr with a Python exception set when the argument is missing,
// is not a dict, or the version lookup fails.
static PyObject* dict_version(PyObject* dummy, PyObject* args) {
  PyObject* obj = nullptr;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return nullptr;
  }
  if (!PyDict_Check(obj)) {
    // Returning nullptr without an exception set would surface as a
    // SystemError in Python.
    PyErr_SetString(PyExc_TypeError, "expected dict()");
    return nullptr;
  }
  // get_dict_version_unchecked may throw; a C++ exception must not propagate
  // through this C entry point.
  try {
    return THPUtils_packUInt64(get_dict_version_unchecked(obj));
  } catch (const std::exception& e) {
    // Keep a more specific Python error if one is already pending.
    if (!PyErr_Occurred()) {
      PyErr_SetString(PyExc_RuntimeError, e.what());
    }
    return nullptr;
  }
}
/*
 Assert that a given tensor has a given size/stride, but ignore strides
 of size==1 dimensions. Implemented in C++ as this is on the hot path.

 Expects (tensor, size_tuple, stride_tuple). Returns True on success and
 raises AssertionError listing every mismatching dimension otherwise.
*/
static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
  // Unpacking the expected sizes/strides can throw (e.g. a non-integer tuple
  // item); translate C++ exceptions into Python errors like the other
  // module-level functions in this file do.
  HANDLE_TH_ERRORS;
  PyObject* item = nullptr;
  PyObject* size = nullptr;
  PyObject* stride = nullptr;
  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
    return nullptr;
  }
  if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
    return nullptr;
  }
  if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return nullptr;
  }
  at::Tensor tensor = THPVariable_Unpack(item);
  int64_t ndim = tensor.ndimension();
  if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
    return nullptr;
  }
  // Collect all mismatches so a single error reports every bad dimension.
  std::stringstream msg;
  int num_errors = 0;
  for (auto i : c10::irange(ndim)) {
    int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i));
    int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i));
    int64_t actual_size = tensor.size(i);
    int64_t actual_stride = tensor.stride(i);
    if (want_size != actual_size ||
        // ignore stride differences when size is 1
        (want_stride != actual_stride && actual_size > 1)) {
      if (num_errors > 0)
        msg << "; ";
      msg << "expected size " << actual_size << "==" << want_size << ", stride "
          << actual_stride << "==" << want_stride << " at dim=" << i;
      num_errors++;
    }
  }
  if (num_errors) {
    PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
    return nullptr;
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS;
}
// Appends every element of the Python tuple `obj` to `output` as an integer.
// Fails a TORCH_CHECK when `obj` is not an exact tuple or when an element
// converts to a negative value.
template <typename T>
static void unwrap_size_tuple(PyObject* obj, T& output) {
  TORCH_CHECK(PyTuple_CheckExact(obj));
  const size_t count = PyTuple_GET_SIZE(obj);
  output.reserve(count);
  for (size_t pos = 0; pos < count; ++pos) {
    PyObject* element = PyTuple_GET_ITEM(obj, pos);
    auto value = PyLong_AsSsize_t(element);
    TORCH_CHECK(value >= 0);
    output.emplace_back(value);
  }
}
// Parses the `(sizes, strides, dtype)` argument tuple shared by the
// _empty_strided_* entry points into the given output parameters.
// Fails a TORCH_CHECK when `args` is not an exact 3-tuple or when the third
// element is not a dtype object.
template <typename T>
static void _parse_empty_strided_args(
    PyObject* args,
    T& sizes,
    T& strides,
    at::ScalarType& dtype) {
  TORCH_CHECK(PyTuple_CheckExact(args));
  TORCH_CHECK(PyTuple_GET_SIZE(args) == 3);

  // note PyTuple_GET_ITEM returns a borrowed ref, so no need for refcounts
  PyObject* py_sizes = PyTuple_GET_ITEM(args, 0);
  PyObject* py_strides = PyTuple_GET_ITEM(args, 1);
  PyObject* py_dtype = PyTuple_GET_ITEM(args, 2);

  unwrap_size_tuple(py_sizes, sizes);
  unwrap_size_tuple(py_strides, strides);
  TORCH_CHECK(THPDtype_Check(py_dtype));
  dtype = reinterpret_cast<THPDtype*>(py_dtype)->scalar_type;
}
// Shared implementation of the _empty_strided_* entry points: parses
// `(sizes, strides, dtype)` from `args` and allocates an empty strided tensor
// on `device_type`. Raises when the build lacks support for that device.
static PyObject* _empty_strided_device(
    PyObject* dummy,
    PyObject* args,
    c10::DeviceType device_type) {
  HANDLE_TH_ERRORS;
  at::SmallVector<int64_t, 8> sizes;
  at::SmallVector<int64_t, 8> strides;
  at::ScalarType dtype{at::ScalarType::Undefined};
  _parse_empty_strided_args(args, sizes, strides, dtype);

  switch (device_type) {
    case c10::DeviceType::CPU:
      return THPVariable_Wrap(
          at::detail::empty_strided_cpu(sizes, strides, dtype));
#ifdef USE_CUDA
    case c10::DeviceType::CUDA:
      return THPVariable_Wrap(at::detail::empty_strided_cuda(
          sizes, strides, dtype, c10::DeviceType::CUDA));
#endif
#ifdef USE_XPU
    case c10::DeviceType::XPU:
      return THPVariable_Wrap(at::detail::empty_strided_xpu(
          sizes, strides, dtype, c10::DeviceType::XPU));
#endif
    default:
      TORCH_CHECK(
          false, "PyTorch compiled without support for the specified device.");
  }
  END_HANDLE_TH_ERRORS;
}
// CPU entry point: forwards `args` to _empty_strided_device with
// c10::DeviceType::CPU.
static PyObject* _empty_strided_cpu(PyObject* dummy, PyObject* args) {
// at::empty_strided is surprisingly slow. This is a lower-overhead
// version that saves ~2us on every allocation.
return _empty_strided_device(dummy, args, c10::DeviceType::CPU);
}
// CUDA entry point: forwards `args` to _empty_strided_device with
// c10::DeviceType::CUDA.
static PyObject* _empty_strided_cuda(PyObject* dummy, PyObject* args) {
// at::empty_strided is surprisingly slow. This is lower-overhead.
return _empty_strided_device(dummy, args, c10::DeviceType::CUDA);
}
// XPU entry point: forwards `args` to _empty_strided_device with
// c10::DeviceType::XPU.
static PyObject* _empty_strided_xpu(PyObject* dummy, PyObject* args) {
// at::empty_strided is surprisingly slow. This is lower-overhead.
return _empty_strided_device(dummy, args, c10::DeviceType::XPU);
}
// Python binding for torch::inductor::_reinterpret_tensor: parses
// (base, sizes, strides, offset_increment=0) from the positional arguments
// and returns the wrapped result.
static PyObject* _reinterpret_tensor(PyObject* dummy, PyObject* args) {
  HANDLE_TH_ERRORS;
  static PythonArgParser parser(
      {"_reinterpret_tensor(Tensor base, IntArrayRef sizes, IntArrayRef strides, int64_t offset_increment=0)"},
      /*traceable=*/true);

  ParsedArgs<4> parsed_args;
  auto parsed = parser.parse(args, /*kwargs=*/nullptr, parsed_args);

  Tensor base = parsed.tensor(0);
  auto new_sizes = parsed.intlist(1);
  auto new_strides = parsed.intlist(2);
  auto offset_increment = parsed.toInt64(3);

  auto reinterpreted = torch::inductor::_reinterpret_tensor(
      base, new_sizes, new_strides, offset_increment);
  return torch::autograd::utils::wrap(reinterpreted);
  END_HANDLE_TH_ERRORS;
}
// Module-level function table referenced by `_module` below. Every entry
// takes positional arguments only (METH_VARARGS) and has no docstring.
// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
static PyMethodDef _methods[] = {
{"check_type_id", check_type_id, METH_VARARGS, nullptr},
{"check_obj_id", check_obj_id, METH_VARARGS, nullptr},
{"assert_size_stride", assert_size_stride, METH_VARARGS, nullptr},
{"dict_version", dict_version, METH_VARARGS, nullptr},
{"_empty_strided_cpu", _empty_strided_cpu, METH_VARARGS, nullptr},
{"_empty_strided_cuda", _empty_strided_cuda, METH_VARARGS, nullptr},
{"_empty_strided_xpu", _empty_strided_xpu, METH_VARARGS, nullptr},
{"_reinterpret_tensor", _reinterpret_tensor, METH_VARARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
// Module definition for `torch._C._dynamo.guards`. The positional fields
// after the head initializer are the module name, its docstring, the
// per-module state size (-1) and the function table `_methods`.
static struct PyModuleDef _module = {
PyModuleDef_HEAD_INIT,
"torch._C._dynamo.guards",
"Module containing checks on tensors",
-1,
_methods};
std::string get_exception_message() {
PyObject *ptype = nullptr, *pvalue = nullptr, *ptraceback = nullptr;
PyErr_Fetch(&ptype, &pvalue, &ptraceback);
PyObject* exc_message_pyobj = PyObject_Str(pvalue);
const char* exc_message = PyUnicode_AsUTF8(exc_message_pyobj);
Py_DECREF(exc_message_pyobj);
Py_XDECREF(ptype);
Py_XDECREF(pvalue);
Py_XDECREF(ptraceback);
return std::string(exc_message);
}
// Returns true when `example_value` is an int, float, bool, str or tensor, or
// a tuple whose elements all (recursively) satisfy the same condition.
bool is_immutable_object(py::handle example_value) {
  PyObject* obj = example_value.ptr();

  if (!PyTuple_Check(obj)) {
    return PyLong_Check(obj) || PyFloat_Check(obj) || PyBool_Check(obj) ||
        PyUnicode_Check(obj) || THPVariable_Check(obj);
  }

  // Check that each element is immutable
  const Py_ssize_t count = PyTuple_Size(obj);
  for (Py_ssize_t idx = 0; idx < count; ++idx) {
    if (!is_immutable_object(py::handle(PyTuple_GetItem(obj, idx)))) {
      return false;
    }
  }
  return true;
}
// Returns true when `tensor` is an instance of `torch.nn.Parameter`.
bool is_parameter(py::handle tensor) {
  py::object parameter_type =
      py::module::import("torch.nn").attr("Parameter");
  const bool result = py::isinstance(tensor, parameter_type);
  return result;
}
/**
 * Dispatches metadata functions to the methods that return integer values,
 * i.e. used whenever static shapes are being used.
 *
 * These are used by the tensor storage overlapping check. Even though their
 * symbolic counterpart does work whenever static shapes are being used, the
 * introduced overhead might significantly worsen the performance.
 *
 * Exposes the same static interface as DynamicMeta (numel, storage_offset,
 * size, stride) so either can be used as the `Meta` template argument.
 */
struct StaticMeta {
// Number of elements of `t`.
static int64_t numel(const Tensor& t) {
return t.numel();
}
// Storage offset of `t`.
static int64_t storage_offset(const Tensor& t) {
return t.storage_offset();
}
// Size of dimension `i` of `t`.
static int64_t size(const Tensor& t, int64_t i) {
return t.size(i);
}
// Stride of dimension `i` of `t`.
static int64_t stride(const Tensor& t, int64_t i) {
return t.stride(i);
}
};
/**
 * Dispatches metadata functions to the methods that return c10::SymInt
 * values, i.e. used whenever dynamic shapes are being used.
 *
 * Exposes the same static interface as StaticMeta (numel, storage_offset,
 * size, stride) so either can be used as the `Meta` template argument.
 */
struct DynamicMeta {
// Symbolic number of elements of `t`.
static SymInt numel(const Tensor& t) {
return t.sym_numel();
}
// Symbolic storage offset of `t`.
static SymInt storage_offset(const Tensor& t) {
return t.sym_storage_offset();
}
// Symbolic size of dimension `i` of `t`.
static SymInt size(const Tensor& t, int64_t i) {
return t.sym_size(i);
}
// Symbolic stride of dimension `i` of `t`.
static SymInt stride(const Tensor& t, int64_t i) {
return t.sym_stride(i);
}
};
/**
 * Assumption: x and y are known to share a storage, and we are trying to
 * determine if their memory is actually completely disjoint, based on
 * sizes/strides/storage_offset
 *
 * "Meta" should be one of the "*Meta" classes above. They dictate which
 * version of the metadata functions we should be using (symbolic vs.
 * concrete). Even though they have the same apparent behavior, the symbolic
 * version introduces a bit of overhead. Such an overhead might end up
 * becoming relevant if it's run enough times.
 *
 * Returns true only when one of the checks below establishes that the two
 * tensors do not overlap. Returns false both when an overlap is detected and
 * when none of the checks applies.
 */
template <class Meta>
bool tensors_definitely_do_not_overlap(const Tensor& x, const Tensor& y) {
// x and y are the same tensor (per is_same): report a possible overlap.
if (x.is_same(y)) {
return false;
}
// If either tensor has no elements, report no overlap.
if (Meta::numel(x) == 0 || Meta::numel(y) == 0) {
return true;
}
// Make x always on the left
if (Meta::storage_offset(x) > Meta::storage_offset(y)) {
return tensors_definitely_do_not_overlap<Meta>(y, x);
}
// Short-circuit in the "obvious" overlapping case: both tensors are
// contiguous
if (x.is_contiguous() && y.is_contiguous()) {
if (Meta::storage_offset(x) + Meta::numel(x) > Meta::storage_offset(y)) {
// definitely overlap
return false;
} else {
// definitely no overlap
return true;
}
}
// Short-circuit: if last memory address of x is < start of y, then not
// overlapping.
auto x_last = Meta::storage_offset(x);
for (int64_t i = 0; i < x.dim(); i++) {
x_last += (Meta::size(x, i) - 1) * Meta::stride(x, i);
}
if (x_last < Meta::storage_offset(y)) {
return true;
}
if (x.dim() == 2 && y.dim() == 2 && Meta::stride(x, 1) == 1 &&
Meta::stride(y, 1) == 1) {
// This case is needed for the shampoo optimizer.
// All tensors are 2d (non-contiguous), have the same outer stride, and have
// an inner stride of 1 (so rows are contiguous)
if (Meta::stride(x, 0) == Meta::stride(y, 0)) {
auto offset_delta = Meta::storage_offset(y) - Meta::storage_offset(x);
if (offset_delta < Meta::size(x, 1)) {
// definitely overlaps (row 0 of y overlaps with row 0 of x)
// Example:
// base = torch.arange(32).reshape(4, 8)
// x = base.narrow(1, 0, 4)
// x: size=(4, 4), stride=(8, 1), offset=0
// y = base.narrow(1, 3, 4)
// y: size=(4, 4), stride=(8, 1), offset=3
return false;
}
auto x_total_elems_covered =
Meta::stride(x, 0) * (Meta::size(x, 0) - 1) + Meta::size(x, 1);
if (x_total_elems_covered <= offset_delta) {
// definitely does not overlap (last byte of x is before start of y)
// Example:
// x: size=(4, 4), stride=(8, 1), offset=0 (last byte is 27)
// y: size=(4, 4), stride=(8, 1), offset=28 (start byte is 28)
return true;
}
// At this point, we want to check if the 0th row of y
// overlaps with **some** row of x.
// We can check this by shifting y backward by the shared stride,
// repeatedly, until the first row of y is before the first row of x. Then
// we can check if these rows overlap. We can accomplish this by modding
// our offset by the stride.
auto offset_delta_mod = offset_delta % Meta::stride(x, 0);
// Example:
// 0 1 2 3
// 9 10 11 12
// 18 19 20 21
// 27 28 29 30
// x: size=(4, 4), stride=(9, 1), offset=0
// y: size=(4, 4), stride=(9, 1), offset=22 (this would not overlap)
// y: size=(4, 4), stride=(9, 1), offset=23 (this would not overlap)
// y: size=(4, 4), stride=(9, 1), offset=24 (this would overlap)
// y: size=(4, 4), stride=(9, 1), offset=25 (this would overlap)
// If the interval [offset_delta_mod, offset_delta_mod + Meta::size(y, 1))
// ends at or before the shared row stride, report no overlap.
if (offset_delta_mod + Meta::size(y, 1) <= Meta::stride(x, 0)) {
return true;
}
}
}
// None of the checks above proved the tensors disjoint.
return false;
}
/**
* Computes the indices of the tensors that might overlap.
*
* Checks which of the given tensors have overlapping storages with ANY of
* the other tensors.
*
* So, for example, if tensor 1 overlaps with tensor 2, and tensor 3 with
* tensor 4, all of them will be in the output of this function. Even if
* tensor 1 and 4 don't overlap.
*/
template <class Meta>
std::unordered_set<int64_t> compute_overlapping_tensors(
const std::vector<Tensor>& tensors) {
std::unordered_set<int64_t> aliased_tensor_indices;
for (int64_t i = 0; i < static_cast<int64_t>(tensors.size()); i++) {
auto tensor_i = tensors[i];
for (int64_t j = 0; j < i; j++) {
if (!tensors_definitely_do_not_overlap<Meta>(tensor_i, tensors[j])) {
aliased_tensor_indices.insert(i);
aliased_tensor_indices.insert(j);
}
}
}
return aliased_tensor_indices;
}
/**
* Checks whether the storage overlapping relation is preserved.
*
* At this point, `non_overlapping` represents the tensors that should not
* have overlapping storages. Similarly, `overlapping` represents the tensors
* that should have overlapping storage in some way (or that we can't be sure).
*
* This function checks whether the assumption above is true or not.
*/
bool check_overlapping(
const std::vector<Tensor>& overlapping,
const std::vector<Tensor>& non_overlapping) {
// Merge the tensor lists.
std::vector<Tensor> tensors;
tensors.reserve(overlapping.size() + non_overlapping.size());
tensors.insert(tensors.end(), overlapping.begin(), overlapping.end());
tensors.insert(tensors.end(), non_overlapping.begin(), non_overlapping.end());
// Check what is the current storage overlapping relation.
auto indices = compute_overlapping_tensors<StaticMeta>(tensors);
// Check that the set of indices of tensors that might overlap is equal to
// the indices of the first `overlapping.size()` tensors. That's because
// `overlapping` tensors were in the beginning of `tensors` list.
auto range = c10::irange(overlapping.size());
return indices.size() == overlapping.size() &&
std::all_of(range.begin(), range.end(), [&](int64_t i) {
return indices.count(i) == 1;
});
}
/**
* Class responsible for collecting and checking the storage overlap relations.
*
* The way GuardManager is implemented, when STORAGE_OVERLAPPING guard check is
* run on a given tensor, we don't know if it is an overlapping or
* non-overlapping tensor. There's no order to which GuardManager runs the guard
* check so that we can split it in 2.
*
* Since we are only interested in the classification of each tensor (not
* necessarily the order), we can just issue 2 STORAGE_OVERLAPPING guards
* representing the overlapping tensors and the non-overlapping ones.
*
* In order to collect the information from both guards (so that we can call
* `check_overlapping` function correctly), we need this class which stores
* both kinds of tensors, and knows when it has collected each one of them.
*/
class StorageOverlapChecker {
public:
StorageOverlapChecker(
size_t expected_overlapping,
size_t expected_non_overlapping)
: _expected_overlapping(expected_overlapping),
_expected_non_overlapping(expected_non_overlapping) {}
/**
* Adds a tensor to the corresponding storage, based on whether it should be
* an `overlapping` tensor or not.
*/
void add(PyObject* obj, bool overlapping) {
// Just check that `obj` is actually a tensor, so that we can keep it alive
// by incrementing its ref-count.
TORCH_CHECK(THPVariable_CheckExact(obj) || THPVariable_Check(obj));
Py_INCREF(obj);
_get(overlapping).push_back(obj);
}
void reset(bool overlapping) {
auto& vec = _get(overlapping);
for (auto item : vec) {
Py_DECREF(item);
}
vec.clear();
}
/**
* Maybe checks the storage overlapping relation.
*
* Before actually calling `check_overlapping` function, this function makes
* sure it has collected all expected tensors.
*/
bool maybe_check() {
TORCH_CHECK(_expected_overlapping >= _overlapping.size());
TORCH_CHECK(_expected_non_overlapping >= _non_overlapping.size());
if (_expected_overlapping == _overlapping.size() &&
_expected_non_overlapping == _non_overlapping.size()) {
// Transform each list of PyObject* into an actual list of Tensors.
auto overlapping_tensors =
_tensors_from(_overlapping, _expected_overlapping);
auto non_overlapping_tensors =
_tensors_from(_non_overlapping, _expected_non_overlapping);
return check_overlapping(overlapping_tensors, non_overlapping_tensors);
} else {
// If we haven't collected them all yet, keep on running.
return true;
}
}
private:
/**
* Returns a reference to the container that corresponds to the given
* overlapping relation.
*/
std::vector<PyObject*>& _get(bool overlapping) {
return overlapping ? _overlapping : _non_overlapping;
}
/**
* Transforms a given list of PyObject* into a list of Tensor.
*/
std::vector<Tensor> _tensors_from(
const std::vector<PyObject*>& objects,
int64_t size) {
std::vector<Tensor> tensors;
tensors.reserve(size);
std::transform(
objects.begin(),
objects.end(),
std::back_inserter(tensors),
[=](PyObject* obj) { return THPVariable_Unpack(obj); });
return tensors;
}
// Expected number of possibly overlapping tensors.
size_t _expected_overlapping;
// Expected number of non-overlapping tensors.
size_t _expected_non_overlapping;
// Collected possibly overlapping tensors.
std::vector<PyObject*> _overlapping;
// Collected non-overlapping tensors.
std::vector<PyObject*> _non_overlapping;
};
/**
* Stores relevant guard debug information, e.g., failure str for a LeafGuard
* failure. The data structure is also accessible in Python.
*/
class GuardDebugInfo {
 public:
  // Full form: result, the code parts to report, and the executed-guard count.
  GuardDebugInfo(
      bool result,
      py::list verbose_code_parts,
      int num_guards_executed)
      : result(result),
        verbose_code_parts(std::move(verbose_code_parts)),
        num_guards_executed(num_guards_executed) {}

  // Used on success: verbose_code_parts is left as an empty list.
  GuardDebugInfo(bool result, int num_guards_executed)
      : result(result), num_guards_executed(num_guards_executed) {}

  // Convenience form that reports a single failure string.
  GuardDebugInfo(
      bool result,
      const std::string& failed_reason,
      int num_guards_executed)
      : GuardDebugInfo(result, num_guards_executed) {
    verbose_code_parts.append(failed_reason);
  }

  // Multi-line, human readable dump of all three fields.
  std::string to_string() {
    std::ostringstream out;
    out << "GuardDebugInfo(\n";
    out << "result=" << result << ",\n";
    out << "verbose_code_parts=" << verbose_code_parts << ",\n";
    out << "num_guards_executed=" << num_guards_executed << ")\n";
    return out.str();
  }

  // True when the guard passed, false when it failed.
  bool result;

  // verbose_code_parts of the failed guard. When it holds more than one
  // entry, the recompilation reasoning infra on the Python side can iterate
  // over the list and eval each string to pinpoint the exact code part that
  // failed.
  py::list verbose_code_parts;

  // Running total of executed guards. Handy when debugging whether guard
  // shuffling is working.
  int num_guards_executed;
};
// Forward declarations so the guard classes below can hold pointers to these
// types before their full definitions are seen (e.g. LeafGuard stores a
// RootGuardManager*).
class GuardManager;
class RootGuardManager;
class DictGuardManager;
/**
* Base class for the leaf guard in the GuardManager hierarchy.
*/
class LeafGuard {
public:
// Most guards do not need root guard manager.
LeafGuard(py::object verbose_code_parts)
: _verbose_code_parts(std::move(verbose_code_parts)) {}
// Guards like TENSOR_MATCH require root_guard_manager to access local_state
// shared across all leaf guards.
LeafGuard(RootGuardManager* root_guard_manager, py::object verbose_code_parts)
: _root_guard_manager(root_guard_manager),
_verbose_code_parts(std::move(verbose_code_parts)) {}
// check function could be called from python. This is useful for debugging
// purpose.
bool check(py::handle value) {
return check_nopybind(value.ptr());
}
GuardDebugInfo check_verbose(py::handle value) {
return check_verbose_nopybind(value.ptr());
}
virtual GuardDebugInfo check_verbose_nopybind(
PyObject* value) { // borrowed ref
bool result = check_nopybind(value);
if (!result) {
return GuardDebugInfo(result, _verbose_code_parts, 0);
}
return GuardDebugInfo(true, 0);
}
py::list verbose_code_parts() {
return _verbose_code_parts;
}
// This is on the hot path and avoids any refcounting code from pybind. This
// is not exposed to Python and can only be called from C++.
virtual bool check_nopybind(PyObject* value) = 0;
virtual ~LeafGuard() = default;
protected:
// RootGuardManager has state that is common across all guards like
// LocalState.
RootGuardManager* _root_guard_manager{nullptr};
private:
// This is set while constructing the leaf guard. This is used for identifying
// the cause of recompilation.
py::list _verbose_code_parts;
};
/**
* Represents a leaf guard that accepts the python guard check function. We
* would like to have most of the guards in C++ (to avoid a Python function
* call). But, it will take some time to reach that goal. Also, there might be
* cases where it's too tedious to write an equivalent C++ guard.
*
* LAMBDA_GUARD allows us to gradually move to C++. We can start from all
* guards of type PythonLambaGuard and incrementally move expensive guards to
* C++.
*/
class LAMBDA_GUARD : public LeafGuard {
 public:
  // guard_check_fn must be a Python function; anything else raises a Python
  // TypeError.
  LAMBDA_GUARD(py::object guard_check_fn, py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)) {
    if (py::isinstance<py::function>(guard_check_fn)) {
      _guard_check_fn = py::cast<py::function>(std::move(guard_check_fn));
    } else {
      throw py::type_error("LAMBDA_GUARD expects (callable, str)");
    }
  }

  // Runs the lambda function with the current f_locals value.
  //
  // Returns false (with the Python error indicator cleared) when the call
  // raises or when evaluating the truthiness of its result raises.
  bool check_nopybind(PyObject* value) override { // borrowed ref
    PyObject* x = PyObject_CallOneArg(_guard_check_fn.ptr(), value); // new ref
    if (x == nullptr) {
      // An exception is caught in the lambda function.
      PyErr_Clear();
      return false;
    }
    // PyObject_IsTrue returns -1 on failure (e.g. __bool__ raised). Storing
    // that in a bool would read as "true" and leave an exception pending, so
    // treat it as a guard failure instead.
    int truth = PyObject_IsTrue(x);
    if (truth == -1) {
      PyErr_Clear();
    }
    Py_DECREF(x);
    return truth > 0;
  }

  // Same as check_nopybind, but reports why the guard failed: either the
  // exception message, or the guard's verbose_code_parts when the lambda
  // simply returned a falsy value.
  GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
    PyObject* x = PyObject_CallOneArg(_guard_check_fn.ptr(), value); // new ref
    if (x == nullptr) {
      // An exception is caught in the lambda function.
      std::string exc_message = get_exception_message();
      PyErr_Clear();
      return GuardDebugInfo(false, exc_message, 0);
    }
    int truth = PyObject_IsTrue(x);
    if (truth == -1) {
      // Truthiness evaluation itself raised; surface that exception.
      std::string exc_message = get_exception_message();
      PyErr_Clear();
      Py_DECREF(x);
      return GuardDebugInfo(false, exc_message, 0);
    }
    Py_DECREF(x);
    if (truth > 0) {
      return GuardDebugInfo(true, 0);
    }
    return GuardDebugInfo(false, verbose_code_parts(), 0);
  }

 private:
  // The user provided lambda function for check_fn.
  py::function _guard_check_fn;
};
class TYPE_MATCH : public LeafGuard {
 public:
  // type_id is expected to be id(type(obj)) of the originally guarded object.
  TYPE_MATCH(py::object type_id, py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _expected(py::cast<intptr_t>(std::move(type_id))) {}

  // Passes only when the exact type object of `value` (borrowed ref) has the
  // recorded address; subclasses do not match.
  bool check_nopybind(PyObject* value) override {
    const auto actual_type_id = reinterpret_cast<intptr_t>(Py_TYPE(value));
    return actual_type_id == _expected;
  }

 private:
  // Address of the type object of the original value.
  intptr_t _expected;
};
class ID_MATCH : public LeafGuard {
 public:
  // obj_id is expected to be id(obj) of the originally guarded object.
  ID_MATCH(py::object obj_id, py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _expected(py::cast<intptr_t>(std::move(obj_id))) {}

  // Passes only when `value` (borrowed ref) is the very same object, i.e.
  // its address equals the recorded id.
  bool check_nopybind(PyObject* value) override {
    const auto actual_id = reinterpret_cast<intptr_t>(value);
    return actual_id == _expected;
  }

 private:
  // Address (Python id) of the original object.
  intptr_t _expected;
};
class EQUALS_MATCH : public LeafGuard {
public:
EQUALS_MATCH(py::object value, py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)),
_value(value),
_value_type(Py_TYPE(value.ptr())) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
// Fast path - pointer equality check. Pointer equality checks are ok
// because objects guarded with EQUALS_MATCH are immutable.
if (value != _value.ptr()) {
// Check type
if (Py_TYPE(value) != _value_type) {
return false;
}
int result = PyObject_RichCompareBool(value, _value.ptr(), Py_EQ);
// Check for exception
if (result == -1) {
PyErr_Clear();
return false;
}
return result;
}
return true;
}
private:
// value to compare against. This is py::object so that we hold on to the
// original value and prevent garbage collection. We run EQUALS_MATCH only on
// selected objects which do not have high memory footprint, so holding on to
// these objects is ok.
py::object _value;
// Type of the value
PyTypeObject* _value_type;
};
class RANGE_ITERATOR_MATCH : public LeafGuard {
public:
RANGE_ITERATOR_MATCH(
py::object start,
py::object stop,
py::object step,
py::object type_id,
py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)),
_type_id(py::cast<intptr_t>(std::move(type_id))) {
PyObject* start_obj = start.ptr();
PyObject* stop_obj = stop.ptr();
PyObject* step_obj = step.ptr();
_start = THPUtils_unpackLong(start_obj);
_stop = THPUtils_unpackLong(stop_obj);
_step = THPUtils_unpackLong(step_obj);
TORCH_CHECK(
!PyErr_Occurred(), "values of start/stop/step must fit in a long type");
}
bool check_nopybind(PyObject* value) override { // borrowed ref
// Do a type match first.
// NOLINTNEXTLINE(performance-no-int-to-ptr)
if (Py_TYPE(value) != (void*)_type_id) {
return false;
}
_PyRangeIterObject* iter = (_PyRangeIterObject*)value;
#if IS_PYTHON_3_12_PLUS
long start = iter->start;
#else
long start = iter->start + iter->index * iter->step;
#endif // IS_PYTHON_3_12_PLUS
long stop = iter->start + iter->len * iter->step;
return start == _start && stop == _stop && iter->step == _step;
}
private:
intptr_t _type_id;
// Normalized representation of a range iterator.
long _start;
long _stop;
long _step;
};
class TUPLE_ITERATOR_LEN : public LeafGuard {
public:
TUPLE_ITERATOR_LEN(
py::object length,
py::object type_id,
py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)),
_length(py::cast<Py_ssize_t>(std::move(length))),
_type_id(py::cast<intptr_t>(std::move(type_id))) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
// Do a type match first.
// NOLINTNEXTLINE(performance-no-int-to-ptr)
if (Py_TYPE(value) != (void*)_type_id) {
return false;
}
_PyTupleIterObject* it = (_PyTupleIterObject*)value;
Py_ssize_t length = 0;
if (it->it_seq)
length = PyTuple_GET_SIZE(it->it_seq) - it->it_index;
return length == _length;
}
private:
// Length of the guarded list
Py_ssize_t _length;
intptr_t _type_id;
};
class LENGTH_CHECK : public LeafGuard {
public:
LENGTH_CHECK(py::object value, py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)),
_length(py::cast<Py_ssize_t>(std::move(value))) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
// PySequence_Length returns -1 if the object is not a sequence. So, we
// don't have to test for PySequence_Check.
return PySequence_Length(value) == _length;
}
private:
// Length of the guarded list
Py_ssize_t _length;
};
class DICT_LENGTH : public LeafGuard {
public:
DICT_LENGTH(py::object value, py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)),
_length(py::cast<Py_ssize_t>(std::move(value))) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
return PyDict_Check(value) && PyDict_Size(value) == _length;
}
private:
// Length of the guarded dict
Py_ssize_t _length;
};
class NOT_NONE : public LeafGuard {
 public:
  NOT_NONE(py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)) {}

  // Passes for every object except the None singleton. `value` is a
  // borrowed reference.
  bool check_nopybind(PyObject* value) override {
    const bool is_none = (value == Py_None);
    return !is_none;
  }
};
class DEFAULT_DEVICE : public LeafGuard {
 public:
  // Snapshots torch.utils._device.CURRENT_DEVICE, and the module dict it
  // lives in, at guard construction time.
  DEFAULT_DEVICE(py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)) {
    // Hold the module in an owning py::object while we read from it.
    py::object device_module = py::module::import("torch.utils._device");
    // Save the dict using py::object
    _utils_device_dict = device_module.attr("__dict__");
    _device = _utils_device_dict["CURRENT_DEVICE"];
  }

  // Passes when the module's current CURRENT_DEVICE entry is the saved object
  // or compares equal to it. The `value` argument is not used. A missing
  // entry or a comparison that raises fails the guard.
  bool check_nopybind(PyObject* value) override { // borrowed ref
    // Create a static interned string. Interned string is faster than creating
    // a new string every time. Even though its a new reference, we don't dec
    // ref it. Interned strings are used for things like variable names and are
    // leaked by design.
    static PyObject* current_device_str =
        PyUnicode_InternFromString("CURRENT_DEVICE");
    PyObject* device = PyDict_GetItem(
        _utils_device_dict.ptr(), current_device_str); // borrowed ref
    if (device == nullptr) {
      // PyDict_GetItem returns NULL when the key is absent. Passing NULL on
      // to PyObject_RichCompareBool would crash, so treat a removed
      // CURRENT_DEVICE entry as a guard failure.
      return false;
    }
    if (device != _device.ptr()) {
      int result = PyObject_RichCompareBool(device, _device.ptr(), Py_EQ);
      if (result == -1) {
        PyErr_Clear();
        return false;
      }
      return result;
    }
    return true;
  }

 private:
  // Save the current device and the module dict during the guard construction.
  py::object _utils_device_dict;
  py::object _device;
};
class GLOBAL_STATE : public LeafGuard {
 public:
  // Creates a GlobalStateGuard and initializes it at construction time.
  GLOBAL_STATE(py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _guard(std::make_unique<GlobalStateGuard>()) {
    _guard->init();
  }

  // The `value` argument (borrowed ref) is ignored; it only exists to
  // satisfy the LeafGuard interface.
  bool check_nopybind(PyObject* value) override {
    return _guard->check();
  }

  // Verbose variant: on failure the message embeds the reason reported by
  // the underlying GlobalStateGuard.
  GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
    if (_guard->check()) {
      return GuardDebugInfo(true, 1);
    }
    return GuardDebugInfo(
        false, "GLOBAL_STATE changed: " + _guard->reason(), 0);
  }

 private:
  // Owned GlobalStateGuard that this leaf guard delegates to.
  std::unique_ptr<GlobalStateGuard> _guard;
};
class DATA_PTR_MATCH : public LeafGuard {
 public:
  // Records the data pointer of `tensor`. Throws std::runtime_error when
  // `tensor` is not a tensor object.
  DATA_PTR_MATCH(py::object tensor, py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)) {
    PyObject* raw = tensor.ptr();
    if (!(THPVariable_CheckExact(raw) || THPVariable_Check(raw))) {
      throw std::runtime_error("DATA_PTR_MATCH guard requires a tensor");
    }
    _data_ptr = THPVariable_Unpack(raw).data_ptr();
  }

  // Passes when `value` (borrowed ref) is a tensor whose data pointer equals
  // the recorded one. Non-tensors fail the guard.
  bool check_nopybind(PyObject* value) override {
    const bool is_tensor =
        THPVariable_CheckExact(value) || THPVariable_Check(value);
    if (!is_tensor) {
      return false;
    }
    return THPVariable_Unpack(value).data_ptr() == _data_ptr;
  }

 private:
  // Data pointer of the tensor seen at construction time.
  void* _data_ptr;
};
// Checks that an attr is absent in the object. We don't need the opposite
// HASATTR guard because we can just rely on GetAttrGuardAccessor to act as
// HASATTR guard.
class NO_HASATTR : public LeafGuard {
 public:
  // attr_name is the attribute that must be absent from the guarded object.
  NO_HASATTR(py::object attr_name, py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _attr_name(std::move(attr_name)) {}

  // Passes when PyObject_HasAttr reports that `value` (borrowed ref) does
  // not have the attribute.
  bool check_nopybind(PyObject* value) override {
    const int has_attr = PyObject_HasAttr(value, _attr_name.ptr());
    return has_attr == 0;
  }

 private:
  // Attribute name, kept alive for the lifetime of the guard.
  py::object _attr_name;
};
// Checks that dict contains or does not contain a key. This happens for
// PythonSysModulesVariable tracker.
// TODO(janimesh) - Check if we can use DictGuardManager. The downside could be
// large number of keys for sys module, so DICT_CONTAINS might still end up
// being faster.
class DICT_CONTAINS : public LeafGuard {
 public:
  // contains == true means the key must be present; false means it must be
  // absent.
  DICT_CONTAINS(bool contains, py::object key, py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _contains(contains ? 1 : 0),
        _key(std::move(key)) {}

  // Passes when the membership of the key in `value` (borrowed ref) matches
  // the expectation. A lookup that raises fails the guard and the exception
  // is cleared.
  bool check_nopybind(PyObject* value) override {
    const int found = PyDict_Contains(value, _key.ptr());
    if (found == -1) {
      // The lookup itself raised (e.g. unhashable key).
      PyErr_Clear();
      return false;
    }
    return found == _contains;
  }

 private:
  // Expected PyDict_Contains result: 1 for present, 0 for absent.
  int _contains;
  // Key to look up, kept alive for the lifetime of the guard.
  py::object _key;
};
/**
* Relational guards compare more than one value. We implement Relational
* guards by capturing some state in the guard object. For example for tensor
* aliasing guards - tensor X is not tensor Y - we construct one leaf guard
* and install it as a leaf of two guard managers (one for X and
* another for Y). Therefore, this guard is run twice. In the first
* invocation, it saves the first value (state) and returns True. In the
* second invocation, it compares the saved value with the new value and
* returns True if they do not alias.
*
* We have to be careful about resetting in case the other guards fail and we
* have some state in the relational guard. This is done by virtual method
* reset_state(). This is called by the RootGuardManager before it exits.
*
*/
class RelationalGuard : public LeafGuard {
 public:
  RelationalGuard(py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)) {}

  // Drops whatever state the guard accumulated across invocations. The guard
  // manager calls this on guard failure, so that a later run starts from a
  // clean slate.
  virtual void reset_state() = 0;
};
/**
* Checks that object x is object y.
*/
class OBJECT_ALIASING : public RelationalGuard {
public:
OBJECT_ALIASING(py::object verbose_code_parts)
: RelationalGuard(std::move(verbose_code_parts)) {}
bool check_nopybind(PyObject* value) override { // borrowed ref
if (_is_first_call) {
_first_tensor = value;
_is_first_call = false;
return true;
}
return _first_tensor == value;
}
void reset_state() final {
_is_first_call = true;
}
private:
bool _is_first_call{true};
PyObject* _first_tensor{nullptr};
};
/**
* Checks that none of the tensors alias.
*/
class NO_TENSOR_ALIASING : public RelationalGuard {
 public:
  // tensor_names is stored as given; its size is used to pre-size the set of
  // tensors seen so far.
  NO_TENSOR_ALIASING(
      const py::list& tensor_names,
      py::object verbose_code_parts)
      : RelationalGuard(std::move(verbose_code_parts)),
        _tensor_names(tensor_names) {
    _unique_tensors.reserve(tensor_names.size());
  }

  // Records `value` and fails if the very same object was already recorded
  // since the last reset_state().
  bool check_nopybind(PyObject* value) override { // borrowed ref
    auto insertion = _unique_tensors.insert({value, nullptr});
    if (!insertion.second) {
      // Duplicate: the map already holds the one reference we own for this
      // object, so no additional reference is taken (taking one here would
      // leak, since reset_state releases a single reference per entry).
      // No need to clear _unique_tensors, reset_state will do it.
      return false;
    }
    // Typically we don't have to increment the ref count here because the
    // tensors are held in f_locals. But there is a special case for
    // `from_numpy` source. `from_numpy` converts integers and such into tensors
    // and these tensors are ephemeral. If we don't incref, those tensors can be
    // garbage collected, and the next time from_numpy can reuse the memory
    // address. Therefore, we incref here. They are decref'd in reset_state.
    Py_INCREF(value);
    return true;
  }

  // Verbose variant of check_nopybind with a fixed failure message.
  GuardDebugInfo check_verbose_nopybind(PyObject* value) override {
    bool result = check_nopybind(value);
    if (!result) {
      return GuardDebugInfo(
          false, "Duplicate tensor found where not expected!", 0);
    }
    return GuardDebugInfo(true, 1);
  }

  // Releases the reference held for every recorded tensor and forgets them.
  void reset_state() final {
    for (auto item : _unique_tensors) {
      Py_DECREF(item.first);
    }
    _unique_tensors.clear();
  }

 private:
  // Names passed at construction time.
  py::list _tensor_names;
  // Tensors seen since the last reset; each key is an owned reference.
  ska::flat_hash_map<PyObject*, std::nullptr_t> _unique_tensors;
};
/**
* Checks the storage overlapping relation of input tensors.
*
* This guard is always installed in pairs: one for the possibly overlapping
* tensors, and another one for the non-overlapping tensors. This is so we can
* correctly identify the given tensor in the check method as one of the 2
* classes mentioned above.
*
* In the end, the one responsible for storing and checking is the
* `StorageOverlapChecker` class.
*/
class STORAGE_OVERLAPPING : public RelationalGuard {
 public:
  // overlapping selects which bucket of the shared checker this guard feeds.
  // checker is shared with the other guard of the pair.
  STORAGE_OVERLAPPING(
      bool overlapping,
      std::shared_ptr<StorageOverlapChecker> checker,
      py::object verbose_code_parts)
      : RelationalGuard(std::move(verbose_code_parts)),
        _overlapping(overlapping),
        // `checker` is taken by value, so move it instead of paying for a
        // second shared_ptr copy.
        _checker(std::move(checker)) {}

  // Hands `value` to the shared checker and lets it decide whether the
  // overlap relation can be (and is) satisfied yet.
  bool check_nopybind(PyObject* value) override {
    _checker->add(value, _overlapping);
    return _checker->maybe_check();
  }

  // Clears only the bucket this guard is responsible for.
  void reset_state() final {
    _checker->reset(_overlapping);
  }

 private:
  // Flag that indicates which kind of tensor this guard is collecting:
  //   1. Possibly overlapping tensors; or
  //   2. Non-overlapping tensors.
  bool _overlapping;
  // Actual checker for this guard.
  std::shared_ptr<StorageOverlapChecker> _checker;
};
class DYNAMIC_INDICES : public LeafGuard {
  // C++ equivalent of
  //  code.append(
  //      f"(({tensor_name}._dynamo_dynamic_indices.issubset({value._dynamo_dynamic_indices}))
  //      if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)"  #
  //      noqa: B950
  //  )
 public:
  // dynamic_indices is the set that the guarded object's
  // _dynamo_dynamic_indices must be a subset of.
  DYNAMIC_INDICES(py::set dynamic_indices, py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)),
        _dynamic_indices(std::move(dynamic_indices)) {}

  // Passes when `value` has no _dynamo_dynamic_indices attribute, or when
  // that attribute's issubset(saved set) is truthy. If the issubset call (or
  // the truthiness test of its result) raises, the guard fails and the
  // exception is cleared.
  bool check_nopybind(PyObject* value) override { // borrowed ref
    // Make an interned string
    static PyObject* dynamic_indices_str =
        PyUnicode_InternFromString("_dynamo_dynamic_indices");
    PyObject* indices = PyObject_GetAttr(value, dynamic_indices_str); // new ref
    if (indices == nullptr) {
      // Attr absent. Clear exception.
      PyErr_Clear();
      // This is true deliberately. If hasattr fails, we return true.
      return true;
    }

    static PyObject* issubset_str = PyUnicode_InternFromString("issubset");
    PyObject* call_result = PyObject_CallMethodObjArgs(
        indices, issubset_str, _dynamic_indices.ptr(), nullptr); // new ref
    if (call_result == nullptr) {
      // The call raised (e.g. the attribute has no usable issubset). Passing
      // NULL to PyObject_IsTrue / Py_DECREF would crash, so fail the guard.
      PyErr_Clear();
      Py_DECREF(indices);
      return false;
    }

    // PyObject_IsTrue returns -1 on error; do not let that read as "true".
    int truth = PyObject_IsTrue(call_result);
    if (truth == -1) {
      PyErr_Clear();
    }
    Py_DECREF(call_result);
    Py_DECREF(indices);
    return truth > 0;
  }

 private:
  // Set of dynamic indices saved at construction time.
  py::set _dynamic_indices;
};
class DICT_VERSION : public LeafGuard {
 public:
  // Records the version tag of `value`. Raises a Python TypeError when
  // `value` is not a dict.
  DICT_VERSION(py::object value, py::object verbose_code_parts)
      : LeafGuard(std::move(verbose_code_parts)) {
    PyObject* dict = value.ptr();
    if (!PyDict_Check(dict)) {
      throw py::type_error("DICT_VERSION expects a dict");
    }
    _tag = get_dict_version_unchecked(dict);
  }

  // Passes when `value` (borrowed ref) is a dict whose version tag equals
  // the recorded one.
  bool check_nopybind(PyObject* value) override {
    if (!PyDict_Check(value)) {
      return false;
    }
    return get_dict_version_unchecked(value) == _tag;
  }

  // Dict version tag captured at construction time.
  uint64_t _tag;
};
// GuardManager can be a pointer to DictGuardManager, but at this point the
// compiler does not know that DictGuardManager is a derived class of
// GuardManager (no way to define inheritance relationships in forward
// declarations), so we forward declare a factory function and define it when
// both DictGuardManager and GuardManager are fully defined.
// Factory used by the GuardAccessor constructor to build the child guard
// manager for an accessor. Only declared here; the definition is not part of
// this section.
std::unique_ptr<GuardManager> make_guard_manager(
RootGuardManager* root,
std::string source,
py::handle example_value,
py::handle guard_manager_enum);
// Clones `from` for `cloned_root`. Callers in this file
// (GuardAccessor::clone_common) treat a nullptr result as "nothing was
// cloned". NOTE(review): presumably nullptr means clone_filter_fn rejected
// the manager - confirm against the definition.
GuardManager* clone_guard_manager(
GuardManager* from,
RootGuardManager* root,
const py::function& clone_filter_fn);
// Called from GuardManager::clone_common for every leaf guard that is a
// RelationalGuard, passing the cloned root and that guard.
void add_relational_guard_resetter_to_cloned_root(
RootGuardManager* root,
std::shared_ptr<RelationalGuard> guard);
/**
* Base class representing a pair of accessor and the associated guard
* manager. The accessor defines how to access the child value from the
* py::object given to the parent check function.
*
* GuardAccessors can be considered equivalent to name() method of Source
* objects in guards.py. In python, name() method returns a str which we can
* then eval in f_locals and f_globals to retrieve the actual py object.
* GuardAccessor serves the same purpose. The minor difference is that
* GuardManager is a tree structure, so a GuardAccessor just has to retrieve
* the value in the next level in this tree and pass it to the child
* GuardAccessor.
*
* GuardAccessor also owns the GuardManager associated with the retrieved
* value from the GuardAccessor.
*/
class GuardAccessor {
public:
GuardAccessor(
RootGuardManager* root,
py::object accessor_key,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: _guard_manager(make_guard_manager(
root,
source,
example_value,
guard_manager_enum)),
_accessor_key(std::move(accessor_key)),
_source(std::move(source)) {}
// Return by reference as GuardAccessor owns the GuardManager.
std::unique_ptr<GuardManager>& get_guard_manager() {
return _guard_manager;
}
bool matches_key(const py::handle& key) const {
return _accessor_key.equal(key);
}
std::string get_source() {
return _source;
}
// matches_dict_tag is used by the DictGetItemGuardAccessor to skip the guard
// subtree on immutable dict getitems.
virtual bool check_nopybind(PyObject* obj, bool matches_dict_tag = false) = 0;
virtual GuardDebugInfo check_verbose_nopybind(PyObject* obj) = 0;
virtual std::string repr() const = 0;
virtual ~GuardAccessor() = default;
public: // Cloning related functions
GuardAccessor(GuardManager* guard_manager, GuardAccessor* from)
: _guard_manager(std::unique_ptr<GuardManager>(guard_manager)) {
from->clone_visitor(this);
}
virtual GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) = 0;
void clone_visitor(GuardAccessor* to) {
to->_source = this->_source;
to->_accessor_key = this->_accessor_key;
}
template <typename DerivedGuardAccessor>
GuardAccessor* clone_common(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) {
GuardManager* cloned_mgr = clone_guard_manager(
get_guard_manager().get(), cloned_root, clone_filter_fn);
if (cloned_mgr == nullptr) {
return nullptr;
}
DerivedGuardAccessor* cloned_accessor =
new DerivedGuardAccessor(cloned_mgr, (DerivedGuardAccessor*)this);
return cloned_accessor;
}
protected:
// Guard manager corresponding to the retrieved value from the
// GuardAccessor.
std::unique_ptr<GuardManager> _guard_manager;
// accessor key could be py::str for getattr, getitem or py::function for
// lambda accessor. It is a py::object because we need to keep these accessor
// keys alive.
py::object _accessor_key;
// A string that can be eval'd on f_locals or f_globals to access the variable
// value. Only used for debugging.
std::string _source;
};
/**
* GuardManager encapsulates all the guards related to a particular
* py::object. It is a tree structure and consists of 1) Leaf guards - Guards
* that are run on the user given object 2) Accessors - Guard accessors (like
* getattr, getitem) to access the next value in the tree hierarchy. Accessor
* object also holds the child GuardManager.
*
* Lets look at an example to understand how it works.
* class Pair:
* int x = 1;
* int y = 2;
*
* At compile time
* >> guard_manager = GuardManager()
* >> guard_manager.x.add_lambda_guard(
* lambda x: isinstance(x, Pair),
* lambda x: f"expected Pair, found {type(x)}"
* )
* >> guard_manager.x.add_lambda_guard(lambda x: x == 1, lambda x: f"found
* {x}, expected 1")
* >> guard_manager.y.add_lambda_guard(lambda x: x == 2, lambda x: f"found
* {x}, expected 2")
*
* At runtime
* >> guard_manager.check(Pair())
*
* At compile time we build the tree structure. When we do `guard_manager.x`,
* it creates an AttrGuardAccessorNode, initializes a child guard manager with
* this accessor node, and adds it as a child. When we do
* `guard_manager.x.add_lambda_guard`, we call add_lambda_guard on the newly
* created guard manager and register a new leaf guard on it.
*
* At runtime, the accessor node has an important function of providing a way
* to access the value for the child guard. In the above example,
* guard_manager.x adds an AttrGuardAccessorNode with attr_name x. When check
* function is called, parent GuardManager calls getattr(value, "x") on its
* value passed to the check function to call the check function of the child
* guard manager.
*
* Performance optimization for fail fast - An optimization for runtime here is
* to sort the execution of child guards depending on the failure count. This
* ensures that we run the guards that are more prone to fail statistically
* first. This can improve the cache lookup time when we have multiple cache
* entries.
*/
// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions)
class GuardManager {
public:
GuardManager() = delete;
GuardManager(RootGuardManager* root, std::string source)
: _root(root), _source(std::move(source)), _is_dict(false) {}
GuardManager(
RootGuardManager* root,
std::string source,
py::handle example_value)
: _root(root),
_source(std::move(source)),
_is_dict(py::isinstance<py::dict>(example_value)) {
if (_is_dict) {
_dict_tag = get_dict_version_unchecked(example_value.ptr());
}
}
GuardManager(const GuardManager& m) = delete;
GuardManager& operator=(const GuardManager&) = delete;
virtual ~GuardManager() = default;
RootGuardManager* get_root() {
return _root;
}
std::string get_source() {
return _source;
}
virtual void add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) {
_leaf_guards.emplace_back(std::move(leaf_guard));
}
public:
// For cloning
GuardManager(RootGuardManager* root, std::string source, bool is_dict)
: _root(root), _source(std::move(source)), _is_dict(is_dict) {}
void clone_common(
RootGuardManager* cloned_root,
GuardManager* cloned_mgr,
const py::function& clone_filter_fn) {
for (const auto& guard : _leaf_guards) {
cloned_mgr->_leaf_guards.emplace_back(guard);
if (std::shared_ptr<RelationalGuard> relational_guard =
std::dynamic_pointer_cast<RelationalGuard>(guard)) {
add_relational_guard_resetter_to_cloned_root(
cloned_root, relational_guard);
}
}
for (const auto& accessor : _accessors) {
GuardAccessor* cloned_accessor =
accessor->clone(cloned_root, clone_filter_fn);
if (cloned_accessor != nullptr) {
cloned_mgr->_accessors.emplace_back(
std::unique_ptr<GuardAccessor>(cloned_accessor));
}
}
}
virtual GuardManager* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) {
if (!py::cast<bool>(clone_filter_fn(this))) {
return nullptr;
}
GuardManager* cloned_mgr = new GuardManager(cloned_root, _source, _is_dict);
clone_common(cloned_root, cloned_mgr, clone_filter_fn);
return cloned_mgr;
}
/**
* Adds a new guard manager with appropriate Accessor. If the accessor is
* already present, we just return the guard manager.
*/
template <typename GuardAccessorT>
GuardManager* get_child_manager(
py::object accessor_key,
std::string source,
py::handle example_value,
py::handle guard_manager_enum) {
// accessor_key type depends on the GuardAccessorT
// for example for GetAttrGuardAccessor - py::str name
// Return the manager if the guard accessor exists
for (const auto& accessor : _accessors) {
if (accessor->matches_key(accessor_key)) {
return accessor->get_guard_manager().get();
}
}
// Construct a new guard accessor
_accessors.emplace_back(std::make_unique<GuardAccessorT>(
_root,
std::move(accessor_key),
source,
example_value,
guard_manager_enum));
return _accessors.back()->get_guard_manager().get();
}
// Runs the leaf guards check and then child managers check function.
//
// NB: There is some code DUPLICATION between this and check_verbose
// function. This is intentional. check function is in the hot path and is
// kept very simple. The purpose of check_verbose function is to get guard
// failure reasoning to understand recompilations. check_verbose function
// does not change the state of the guard, e.g., it does not shuffle the
// guards and does not change the fail count. For simplicity, we duplicate
// the code here.
virtual bool check_nopybind(PyObject* value) { // borrowed ref
if (!this->check_leaf_guards_nopybind(value)) {
return false;
}
return this->check_accessors_nopybind(value);
}
bool check_leaf_guards_nopybind(PyObject* value) {
// Iterate over leaf guards
for (const auto& guard : _leaf_guards) {
if (!guard->check_nopybind(value)) { // early exit
_fail_count += 1;
// no need of sorting, just return.
return false;
}
}
return true;
}
bool check_accessors_nopybind(PyObject* value) {
bool matches_dict_tag = false;
uint64_t new_tag = 0;
if (_is_dict) {
// Check if the dict tag matches. If it does, propagate to the child
// accessors. This will pass to the child manager via
// DictGetItemGuardManager.
new_tag = get_dict_version_unchecked(value);
matches_dict_tag = new_tag == _dict_tag;
}
// Iterate over accessors.
bool result = true;
bool failed_on_first = true;
for (const auto& accessor : _accessors) {
if (!accessor->check_nopybind(value, matches_dict_tag)) { // early exit
_fail_count += 1;
result = false;
// need to sort, so break the loop.
break;
}
failed_on_first = false;
}
// failed_on_first is just an optimization to avoid sorting if we are
// failing on the first accessor itself. This is helpful when we have
// already sorted the guards once, and dont need to sort again.
if (!result && !failed_on_first) {
// Inplace sort the child guards by fail count. This moves the guard
// with higher fail count earlier in the queue, and enables fail fast
// for the next check_verbose.
// An alternate implementation was to use priority queue directly on
// _accessors, but it was rejected because of the complexity of
// popping and creating a new pq on each run_guards. Moreover, this sort
// is happening on the unhappy path when check_verbose guard
// fails. So, its probably ok.
std::sort(
_accessors.begin(),
_accessors.end(),
[](const std::unique_ptr<GuardAccessor>& a,
const std::unique_ptr<GuardAccessor>& b) {
return a->get_guard_manager()->fail_count() >
b->get_guard_manager()->fail_count();
});
}
if (_is_dict && result) {
// If result is true, reset the _dict_tag. This is useful if there is a
// mutation on the dict but it does not change the attr values (like
// swapping).
_dict_tag = new_tag;
}
return result;
}
// This function has some code duplication with function check. This is
// deliberate to keep check function simple and fast.
virtual GuardDebugInfo check_verbose_nopybind(
PyObject* value) { // borrowed ref
int num_guards_executed = 0;
const GuardDebugInfo& debug_info =
check_leaf_guards_verbose_nopybind(value, num_guards_executed);
if (!debug_info.result) {
return debug_info;
}
return check_accessors_verbose_nopybind(value, num_guards_executed);
}
GuardDebugInfo check_leaf_guards_verbose_nopybind(
PyObject* value,
int& num_guards_executed) {
// Iterate over leaf guards
for (const auto& guard : _leaf_guards) {
const GuardDebugInfo& debug_info = guard->check_verbose_nopybind(value);
num_guards_executed++;
if (!debug_info.result) {
return GuardDebugInfo(
false, debug_info.verbose_code_parts, num_guards_executed);
}
}
return GuardDebugInfo(true, num_guards_executed);
}
// Runs the accessors in their current order and stops at the first failure.
// Each accessor reports how many guards it executed; that amount is added to
// num_guards_executed before the result is inspected.
GuardDebugInfo check_accessors_verbose_nopybind(
    PyObject* value,
    int& num_guards_executed) {
  for (auto it = _accessors.begin(); it != _accessors.end(); ++it) {
    const GuardDebugInfo child_result = (*it)->check_verbose_nopybind(value);
    num_guards_executed += child_result.num_guards_executed;
    if (child_result.result) {
      continue;
    }
    // First failing accessor: surface its code parts to the caller.
    return GuardDebugInfo(
        false, child_result.verbose_code_parts, num_guards_executed);
  }
  return GuardDebugInfo(true, num_guards_executed);
}
int64_t fail_count() const {
return _fail_count;
}
// DEBUG function - Returning raw pointers because we can't return unique_ptr
// and pybind does not accept a unique_ptr reference return type.
virtual std::vector<GuardAccessor*> get_accessors() const {
std::vector<GuardAccessor*> ret;
ret.reserve(_accessors.size());
for (const auto& accessor : _accessors) {
ret.emplace_back(accessor.get());
}
return ret;
}
// DEBUG function - Returning raw pointers because we can't return unique_ptr
// and pybind does not accept a unique_ptr reference return type.
virtual std::vector<GuardManager*> get_child_managers() {
std::vector<GuardManager*> ret;
ret.reserve(_accessors.size());
for (const auto& accessor : _accessors) {
ret.emplace_back(accessor->get_guard_manager().get());
}
return ret;
}
// DEBUG function - Returning raw pointers because we can't return unique_ptr
// and pybind does not accept a unique_ptr reference return type.
std::vector<LeafGuard*> get_leaf_guards() const {
std::vector<LeafGuard*> ret;
ret.reserve(_leaf_guards.size());
for (const auto& guard : _leaf_guards) {
ret.push_back(guard.get());
}
return ret;
}
// Returns true when guard_name was previously recorded via insert_leaf_guard.
bool is_leaf_guard_present(const std::string& guard_name) {
  const bool already_inserted = _inserted_leaf_guards.count(guard_name) != 0;
  return already_inserted;
}
// Records guard_name in the set of inserted leaf guards. Recording the same
// name twice has no further effect.
void insert_leaf_guard(const std::string& guard_name) {
  _inserted_leaf_guards.emplace(guard_name);
}
void add_permitted_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) {
// Selectively called for permitted guards. This is used by DictGuardManager
// which overrides the add_leaf_guard manager to throw runtime error.
GuardManager::add_leaf_guard(std::move(leaf_guard));
}
protected:
// Keeps a count of how many times this guard manager's check function
// returned False. This is used for the sorting optimization: on a failure,
// sibling accessors are reordered by the fail count of their guard managers.
int64_t _fail_count{0};
private:
// Root of the guard manager tree. It is used to install the relational
// guard resetters.
RootGuardManager* _root;
// A string that can be used to eval on f_locals or f_globals to get the
// value. This is used only to pass on debugging information.
std::string _source;
// The set of leaf guard names that have already been inserted. This is to
// prevent duplicate guards like TYPE_MATCH.
std::unordered_set<std::string> _inserted_leaf_guards;
// Leaf guards are the terminal guards on this object, e.g., a type check on a
// list. These guards have to be run before any children are run.
//
// These leaf guards are not shufflable. In almost all cases, these guards
// will have an order, e.g., a type(x) is int guard and an x == 5 guard. We
// also expect very few leaf guards per GuardManager node.
//
// NB: Why are leaf guards shared_ptr? This is primarily to enable relational
// guards like `tensor X is not tensor Y`. These guards require multiple
// values. We handle it by creating one guard object that holds state, and this
// guard is installed in many guard managers, hence a shared_ptr.
std::vector<std::shared_ptr<LeafGuard>> _leaf_guards;
// GuardAccessor nodes used to access the child guards. These guards are
// shufflable. On a guard failure, they are sorted based on their fail count
// to enable fail fast for the next check.
std::vector<std::unique_ptr<GuardAccessor>> _accessors;
// When set, _dict_tag is refreshed after a successful accessor check.
// NOTE(review): presumably true when the guarded value is a dict - confirm
// against the constructors.
bool _is_dict;
// Tag remembered from the last successful check of a dict value.
// NOTE(review): looks like the dict version tag - confirm against the code
// that computes the new tag.
uint64_t _dict_tag{0};
};
/**
Note on [Ownership with cloning] - GuardManagers have the facility to clone
themselves. This is useful for cloning a subset of the guard manager into a
diff guard manager.
As far as ownership goes, the model is exactly the same as before. We have unique_ptr
for GuardAccessor and GuardManagers. So, any state required for the accessors
and managers is copied over using constructors and clone_visitor functions.
The main thing to notice is leaf guards. The leaf guards are represented using
shared_ptr, and they are shared (not cloned) with the cloned managers.
So for leaf guard state to be released, both the original and cloned managers
have to be destructed.
*/
/**
* RootGuardManager is the root of the guard tree. This is primarily
* constructed to hold the relational guard pointers so that we can reset the
* state of those guards on guard failure. All the other important
* implementation is in GuardManager class.
*/
class RootGuardManager : public GuardManager {
public:
// This is the root node, set its _root member to nullptr
RootGuardManager() : GuardManager(this, "L") {}
// Adds the relational guard resetter
void add_relational_guard_resetter(
std::shared_ptr<RelationalGuard> relational_guard) {
_relational_guard_resetters.emplace_back(std::move(relational_guard));
}
// Python visible API to check guard function.
bool check(py::handle value) {
return check_nopybind(value.ptr());
}
// Python visible API to check_verbose guard function.
GuardDebugInfo check_verbose(py::handle value) {
return check_verbose_nopybind(value.ptr());
}
// Fast check function.
bool check_nopybind(PyObject* value) override { // borrowed ref
// Check [Note on GIL interaction with mutex lock] for details on why we
// need mutex and its interactions wth GIL.
PyThreadState* _save = nullptr;
Py_UNBLOCK_THREADS; // ; is added to avoid clang-formatting
std::lock_guard<std::mutex> lock_guard(_lock);
Py_BLOCK_THREADS; // ; is added to avoid clang-formatting
// Get the local state. This will be used for TENSOR_MATCH guards.
if (_init_local_state) {
LocalState state;
_local_state = state;
}
if (!GuardManager::check_leaf_guards_nopybind(value)) {
_reset_relational_guard_state();
return false;
}
// Run accessor guards without TorchFunction enabled
// Dynamo should only be adding guards on values without
// torch function at this point, because if there
// was a torch function, we should've traced through it
const at::impl::TorchFunctionDisabledState old_state =
at::impl::PythonTorchFunctionTLS::get_disabled_state();
at::impl::PythonTorchFunctionTLS::set_disabled_state(
at::impl::TorchFunctionDisabledState::ALL_DISABLED);
if (!GuardManager::check_accessors_nopybind(value)) {
at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state);
_reset_relational_guard_state();
return false;
}
// Iterate over epilogue leaf guards.
for (const auto& guard : _epilogue_lambda_guards) {
if (!guard->check_nopybind(value)) { // early exit
at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state);
_reset_relational_guard_state();
return false;
}
}
at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state);
_reset_relational_guard_state();
return true;
}
// Fast check_verbose function.
GuardDebugInfo check_verbose_nopybind(
PyObject* value) override { // borrowed ref
// Check [Note on GIL interaction with mutex lock] for details on why we
// need mutex and its interactions wth GIL.
PyThreadState* _save = nullptr;
Py_UNBLOCK_THREADS; // ; is added to avoid clang-formatting
std::lock_guard<std::mutex> lock_guard(_lock);
Py_BLOCK_THREADS; // ; is added to avoid clang-formatting
// Get the local state. This will be used for TENSOR_MATCH guards.
if (_init_local_state) {
LocalState state;
_local_state = state;
}
int num_guards_executed = 0;
// Run leaf guards
// This includes the GlobalStateGuard and the Torch Function Mode stack
// guard, which require Torch Function to be in its unmodified state
const GuardDebugInfo& debug_info_leaf =
GuardManager::check_leaf_guards_verbose_nopybind(
value, num_guards_executed);
if (!debug_info_leaf.result) {
_reset_relational_guard_state();
return debug_info_leaf;
}
const at::impl::TorchFunctionDisabledState old_state =
at::impl::PythonTorchFunctionTLS::get_disabled_state();
at::impl::PythonTorchFunctionTLS::set_disabled_state(
at::impl::TorchFunctionDisabledState::ALL_DISABLED);
const GuardDebugInfo& debug_info_accessors =
GuardManager::check_accessors_verbose_nopybind(
value, num_guards_executed);
if (!debug_info_accessors.result) {
at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state);
_reset_relational_guard_state();
return debug_info_accessors;
}
// Iterate over epilogue leaf guards
for (const auto& guard : _epilogue_lambda_guards) {
const GuardDebugInfo& tmp_debug_info =
guard->check_verbose_nopybind(value);
num_guards_executed++;
if (!tmp_debug_info.result) {
at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state);
_reset_relational_guard_state();
return GuardDebugInfo(
false, tmp_debug_info.verbose_code_parts, num_guards_executed);
}
}
at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state);
_reset_relational_guard_state();
return GuardDebugInfo(true, num_guards_executed);
}
void add_epilogue_lambda_guard(std::unique_ptr<LeafGuard> leaf_guard) {
_epilogue_lambda_guards.emplace_back(std::move(leaf_guard));
}
void set_init_local_state_flag() {
_init_local_state = true;
}
// See note on [Ownership with cloning]
RootGuardManager* clone_manager(const py::function& clone_filter_fn) {
// Use clone_filter_fn
if (!py::cast<bool>(clone_filter_fn(this))) {
return nullptr;
}
RootGuardManager* cloned_root = new RootGuardManager();
clone_common(cloned_root, cloned_root, clone_filter_fn);
for (const auto& guard : _epilogue_lambda_guards) {
cloned_root->_epilogue_lambda_guards.emplace_back(guard);
}
return cloned_root;
}
// DEBUG function - Returning raw pointers because we can't return unique_ptr
// and pybind does not accept a unique_ptr reference return type.
std::vector<LeafGuard*> get_epilogue_lambda_guards() const {
std::vector<LeafGuard*> ret;
ret.reserve(_epilogue_lambda_guards.size());
for (const auto& guard : _epilogue_lambda_guards) {
ret.push_back(guard.get());
}
return ret;
}
private:
// Reset the state of all the relational guards on failure.
void _reset_relational_guard_state() {
for (auto& guard : _relational_guard_resetters) {
guard->reset_state();
}
}
public:
// Local state for TENSOR_MATCH guards.
LocalState _local_state;
private:
// All the relational guards under this guard mananger. We only use these
// when the guard evaluates to False. This ensures that guard state is reset
// on guard failure so that next invocation is clean.
std::vector<std::shared_ptr<RelationalGuard>> _relational_guard_resetters;
// These guards are lambda guards, i.e., the guards that lack C++
// implementation. For simplicity, we add these guards at the root. They
// MUST be run after all other guard managers have finished to ensure that
// the epilogue guards do not step on some nonexistent getattr or getitem.
// NB - shared_ptr is used to share the epilogue guards with the cloned guard
// manager.
std::vector<std::shared_ptr<LeafGuard>> _epilogue_lambda_guards;
// [Note on GIL interaction with mutex lock]
// We use std::mutex to prevent multiple threads from running
// check/check_verbose simultaneously. This is to prevent race condition due
// to state changes in RelationalGuard.
//
// However, we also need to be careful about GIL interaction with mutex. There
// is a chance of deadlock
//
// Thread 1: has GIL, waiting for lock
// Thread 2: has lock, waiting for GIL
//
// This can happen when Thread 2 earlier acquired the mutex lock, starting
// running the critical section of check function and then called some python
// function (like LAMBDA_GUARD) and reached Cpython codebase that checks if it
// should release the GIL (typically happens after every few bytecode
// instructions). Thread 2 here can decide to release the GIL. Thread 1 can
// acquire GIL and reach the mutex, where it will wait forever.
//
// To avoid this, each thread releases the GIL before acquiring the mutex and
// then acquires the GIL again after acquiring the mutex lock by using
// Py_BLOCK_THREADS and Py_UNBLOCK_THREADS. This avoids the deadlock.
std::mutex _lock;
// We init LocalState only when this flag it set. This flag is set during
// TENSOR_MATCH guard init.
bool _init_local_state = false;
};
/*
* Dicts are common in python code. Therefore, we handle guards for dicts
* differently and use PyDict_* APIs which are faster than PyObject_* APIs
* because of no ref count increments/decrements.
*
* DictGuardManager relies on the order of dict.keys(). It keeps track of the
* indices of dict.keys() to access the key, value pair.
*/
// Pair of (key manager, value manager) for one guarded dict entry. Either
// side may be null when no guard manager was requested for it.
using KeyValueManager =
    std::pair<std::unique_ptr<GuardManager>, std::unique_ptr<GuardManager>>;
class DictGuardManager : public GuardManager {
 public:
  // Captures the size, type and exact-dict-ness of example_value; checks
  // compare incoming objects against these.
  DictGuardManager(
      RootGuardManager* root,
      std::string source,
      py::handle example_value)
      : GuardManager(root, std::move(source)),
        _size(PyDict_Size(example_value.ptr())),
        _expected_type(Py_TYPE(example_value.ptr())),
        _is_exact_dict_type(PyDict_CheckExact(example_value.ptr())) {}

  // Returns the guard manager for the key at position key_index, creating it
  // on first use. The returned pointer is owned by this manager.
  GuardManager* get_key_manager(
      py::object key_index,
      std::string source,
      py::handle example_value,
      py::handle guard_manager_enum) {
    KeyValueManager& key_value_manager =
        _get_index_manager(std::move(key_index));
    if (!key_value_manager.first) {
      key_value_manager.first = make_guard_manager(
          this->get_root(),
          std::move(source),
          example_value,
          guard_manager_enum);
    }
    return key_value_manager.first.get();
  }

  // Returns the guard manager for the value at position key_index, creating
  // it on first use. The returned pointer is owned by this manager.
  GuardManager* get_value_manager(
      py::object key_index,
      std::string source,
      py::handle example_value,
      py::handle guard_manager_enum) {
    KeyValueManager& key_value_manager =
        _get_index_manager(std::move(key_index));
    if (!key_value_manager.second) {
      key_value_manager.second = make_guard_manager(
          this->get_root(),
          std::move(source),
          example_value,
          guard_manager_enum);
    }
    return key_value_manager.second.get();
  }

  // Fast check: type and size first, then the base-class guards, then the
  // key/value managers of the saved indices in dict iteration order.
  bool check_nopybind(PyObject* obj) override { // borrowed ref
    // TODO(janimesh) - Implement a fast-path using dict versions.

    if (Py_TYPE(obj) != _expected_type) {
      _fail_count += 1;
      return false;
    }

    if (PyDict_Size(obj) != _size) {
      _fail_count += 1;
      return false;
    }

    // Early return
    if (_size == 0) {
      return true;
    }

    // Invokes the base class's check_nopybind method. We permit a limited set
    // of leaf guards and accessors within the DictGuardManager framework.
    // Integrating certain guards or accessors directly within the
    // DictGuardManager can be challenging. For instance, `type(dict_object)` as
    // an accessor is permissible, which otherwise would be hard to integrate
    // directly into DictGuardManager. Similarly, incorporating guards such as
    // DICT_CONTAINS and DICT_VERSION as leaf guards offers a simpler solution
    // than embedding these functionalities within the DictGuardManager itself.
    if (!GuardManager::check_nopybind(obj)) {
      _fail_count += 1;
      // No need to shuffle the child guards, just return.
      return false;
    }

    PyObject *key = nullptr, *value = nullptr;
    Py_ssize_t pos = 0;

    // Points to an element in the _indices vector.
    size_t index_pointer = 0;
    // Points to the key index in the dict
    Py_ssize_t dict_pointer = 0;

    while (index_pointer < _indices.size() &&
           PyDict_Next(obj, &pos, &key, &value)) {
      // Skip if dict_pointer is not a saved index.
      if (dict_pointer == _indices[index_pointer]) {
        index_pointer += 1;
        KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
        std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
        if (key_manager && !key_manager->check_nopybind(key)) {
          return false;
        }
        std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
        if (value_manager && !value_manager->check_nopybind(value)) {
          return false;
        }
      }
      dict_pointer += 1;
    }
    return true;
  }

  // Verbose variant of check_nopybind; reports the failing code parts.
  GuardDebugInfo check_verbose_nopybind(
      PyObject* obj) override { // borrowed ref
    if (Py_TYPE(obj) != _expected_type) {
      return GuardDebugInfo(false, "TYPE_MISMATCH(" + get_source() + ")", 0);
    }

    if (PyDict_Size(obj) != _size) {
      return GuardDebugInfo(
          false, "len(" + get_source() + ") != " + std::to_string(_size), 0);
    }

    // Early return
    if (_size == 0) {
      return GuardDebugInfo(true, 0);
    }

    // Invokes the base class's check_verbose_nopybind method. We permit a
    // limited set of leaf guards and accessors within the DictGuardManager
    // framework. Integrating certain guards or accessors directly within the
    // DictGuardManager can be challenging. For instance, `type(dict_object)` as
    // an accessor is permissible, which otherwise would be hard to integrate
    // directly into DictGuardManager. Similarly, incorporating guards such as
    // DICT_CONTAINS and DICT_VERSION as leaf guards offers a simpler solution
    // than embedding these functionalities within the DictGuardManager itself.
    GuardDebugInfo debug_info = GuardManager::check_verbose_nopybind(obj);
    if (!debug_info.result) {
      return debug_info;
    }

    PyObject *key = nullptr, *value = nullptr;
    Py_ssize_t pos = 0;

    // Points to an element in the _indices vector.
    size_t index_pointer = 0;
    // Points to the key index in the dict
    Py_ssize_t dict_pointer = 0;

    // Counts only the guards executed by the key/value managers below.
    int num_guards_executed = 0;
    while (index_pointer < _indices.size() &&
           PyDict_Next(obj, &pos, &key, &value)) {
      // Skip if dict_pointer is not a saved index.
      if (dict_pointer == _indices[index_pointer]) {
        index_pointer += 1;
        KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
        std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
        if (key_manager) {
          GuardDebugInfo key_debug_info =
              key_manager->check_verbose_nopybind(key);
          num_guards_executed += key_debug_info.num_guards_executed;
          if (!key_debug_info.result) {
            return GuardDebugInfo(
                false, key_debug_info.verbose_code_parts, num_guards_executed);
          }
        }
        std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
        if (value_manager) {
          GuardDebugInfo value_debug_info =
              value_manager->check_verbose_nopybind(value);
          num_guards_executed += value_debug_info.num_guards_executed;
          if (!value_debug_info.result) {
            return GuardDebugInfo(
                false,
                value_debug_info.verbose_code_parts,
                num_guards_executed);
          }
        }
      }
      dict_pointer += 1;
    }
    return GuardDebugInfo(true, num_guards_executed);
  }

  void skip_adding_guard(const py::object& a, const py::object& b) {
    // The `add_leaf_guard` method in `DictGuardManager` is overridden to block
    // the addition of leaf guards. However, this is too strict. Python side of
    // guard management frequently adds TYPE_MATCH and DICT_LENGTH on
    // DictGuardManager. We could refactor Python side to never call these
    // guards on dict objects, but that results in messy code. Instead, we just
    // override these two guards to not go through add_leaf_guard code path and
    // skip adding guards. This makes the python side easy.
  }

  // Always throws: accessors cannot be added to a DictGuardManager.
  void fail_on_get_child_manager(
      const py::object& a,
      const std::string& source,
      const py::object& b) {
    throw std::runtime_error("Can not add an accessor to DictGuardManager");
  }

  void add_leaf_guard(std::shared_ptr<LeafGuard> leaf_guard) override {
    // If you are calling this, you probably want to go through a key, value
    // child manager and then add a leaf guard on them. DictGuardManager already
    // has TYPE_MATCH and LENGTH_CHECK built in.
    throw std::runtime_error("DictGuardManager does not support a leaf_guard");
  }

  // Debug helper - Returning raw pointers because we can't return unique_ptr
  // and pybind does not accept a unique_ptr reference return type.
  std::unordered_map<Py_ssize_t, std::pair<GuardManager*, GuardManager*>>
  get_key_value_managers() {
    std::unordered_map<Py_ssize_t, std::pair<GuardManager*, GuardManager*>> ret;
    for (auto index : _indices) {
      ret[index] = std::make_pair(
          _key_value_managers[index].first.get(),
          _key_value_managers[index].second.get());
    }
    return ret;
  }

  bool is_exact_dict_type() {
    return _is_exact_dict_type;
  }

 public: // cloning functions
  // Constructor used while cloning: the state captured from the example
  // value is passed in explicitly instead of being read from a dict.
  DictGuardManager(
      RootGuardManager* cloned_root,
      std::string source,
      Py_ssize_t size,
      PyTypeObject* expected_type,
      bool is_exact_dict_type,
      std::vector<Py_ssize_t> indices)
      : GuardManager(cloned_root, std::move(source), true),
        _size(size),
        _expected_type(expected_type),
        _is_exact_dict_type(is_exact_dict_type),
        _indices(std::move(indices)) {}

  // Clones this manager as a T (DictGuardManager or a subclass of it).
  // Returns nullptr when clone_filter_fn rejects this manager; otherwise
  // returns a newly allocated manager that the caller takes ownership of.
  // Key/value managers rejected by the filter are left null in the clone.
  template <typename T>
  GuardManager* clone_dict_guard_manager(
      RootGuardManager* cloned_root,
      const py::function& clone_filter_fn) {
    if (!py::cast<bool>(clone_filter_fn(this))) {
      return nullptr;
    }
    // Own the clone through a unique_ptr while it is being populated so it is
    // not leaked if clone_filter_fn (a Python callable) or a child clone
    // throws.
    std::unique_ptr<T> cloned_mgr = std::make_unique<T>(
        cloned_root,
        get_source(),
        _size,
        _expected_type,
        _is_exact_dict_type,
        _indices);

    clone_common(cloned_root, cloned_mgr.get(), clone_filter_fn);
    for (auto index : _indices) {
      KeyValueManager& key_value_manager = _key_value_managers[index];
      std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
      std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;

      cloned_mgr->_key_value_managers[index] = std::make_pair(nullptr, nullptr);

      if (key_manager) {
        GuardManager* cloned_key_manager =
            key_manager->clone(cloned_root, clone_filter_fn);
        if (cloned_key_manager) {
          cloned_mgr->_key_value_managers[index].first =
              std::unique_ptr<GuardManager>(cloned_key_manager);
        }
      }

      if (value_manager) {
        GuardManager* cloned_value_manager =
            value_manager->clone(cloned_root, clone_filter_fn);
        if (cloned_value_manager) {
          cloned_mgr->_key_value_managers[index].second =
              std::unique_ptr<GuardManager>(cloned_value_manager);
        }
      }
    }
    return cloned_mgr.release();
  }

  GuardManager* clone(
      RootGuardManager* cloned_root,
      const py::function& clone_filter_fn) override {
    return clone_dict_guard_manager<DictGuardManager>(
        cloned_root, clone_filter_fn);
  }

 private:
  /**
   * Returns the key/value manager pair stored for key_index. If no pair
   * exists yet, the index is recorded in _indices and an empty pair is
   * created and returned.
   */
  KeyValueManager& _get_index_manager(py::object key_index) {
    // Check if the pair is already present.
    Py_ssize_t index = py::cast<Py_ssize_t>(std::move(key_index));
    auto it = _key_value_managers.find(index);
    if (it != _key_value_managers.end()) {
      return it->second;
    }
    _indices.push_back(index);
    // Always keep the _indices array sorted
    std::sort(_indices.begin(), _indices.end());
    _key_value_managers[index] = std::make_pair(nullptr, nullptr);
    return _key_value_managers[index];
  }

 protected: // also used by DictSubclassGuardManager
  // Expected number of entries in the guarded dict.
  Py_ssize_t _size;
  // DictGuardManager supports both exact dict type and non-exact dict type.
  // Therefore, we have to compare the type to early exit.
  PyTypeObject* _expected_type;
  bool _is_exact_dict_type; // Useful to check getattr_manager validity.
  // Sorted positions (in key iteration order) that have a key or value
  // manager.
  std::vector<Py_ssize_t> _indices;
  // Position -> (key manager, value manager).
  std::unordered_map<Py_ssize_t, KeyValueManager> _key_value_managers;
};
/**
* The DictSubclassGuardManager is designed to work with dict subclasses,
* specifically focusing on OrderedDicts. Standard dictionaries leverage the
* PyDict_Next function to iterate over keys, values, and items. OrderedDicts,
* on the other hand, rely on an additional linked list structure to maintain
* keys order. Although PyDict_Next and OrderedDict generally yield the same
* order, discrepancies arise when using OrderedDict's move_to_end method (used
* in Pytorch hooks). `move_to_end` method only updates the linked list, leaving
 * PyDict_Next unaffected. Therefore, to accurately capture key ordering in such
 * cases, DictSubclassGuardManager directly invokes the .keys() method.
*/
class DictSubclassGuardManager : public DictGuardManager {
public:
DictSubclassGuardManager(
RootGuardManager* root,
std::string source,
py::handle example_value)
: DictGuardManager(root, std::move(source), example_value) {}
public:
bool check_nopybind(PyObject* obj) override { // borrowed ref
// TODO(janimesh) - Implement a fast-path using dict versions.
if (Py_TYPE(obj) != _expected_type) {
_fail_count += 1;
return false;
}
if (PyDict_Size(obj) != _size) {
_fail_count += 1;
return false;
}
// Early return
if (_size == 0) {
return true;
}
if (!GuardManager::check_nopybind(obj)) { // NOLINT
_fail_count += 1;
// No need to shuffle the child guards, just return.
return false;
}
// Points to an element in the _indices vector.
size_t index_pointer = 0;
// Points to the key index in the dict
Py_ssize_t dict_pointer = 0;
// Use iter(dict.keys()) to iterate over the keys
py::object keys =
py::handle(obj).attr("keys")(); // py::object handles the references
PyObject* iterator = PyObject_GetIter(keys.ptr()); // new reference
PyObject* key = nullptr;
while (index_pointer < _indices.size() &&
(key = PyIter_Next(iterator))) { // new reference
if (dict_pointer == _indices[index_pointer]) {
KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
if (key_manager && !key_manager->check_nopybind(key)) {
Py_DECREF(key);
Py_DECREF(iterator);
return false;
}
PyObject* value = PyDict_GetItem(obj, key); // borrowed ref
std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
if (value_manager && !value_manager->check_nopybind(value)) {
Py_DECREF(key);
Py_DECREF(iterator);
return false;
}
index_pointer++;
}
dict_pointer++;
Py_DECREF(key);
}
Py_DECREF(iterator);
return true;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
if (Py_TYPE(obj) != _expected_type) {
return GuardDebugInfo(false, "TYPE_MISMATCH(" + get_source() + ")", 0);
}
if (PyDict_Size(obj) != _size) {
return GuardDebugInfo(
false, "len(" + get_source() + ") != " + std::to_string(_size), 0);
}
// Early return
if (_size == 0) {
return GuardDebugInfo(true, 0);
}
GuardDebugInfo debug_info =
GuardManager::check_verbose_nopybind(obj); // NOLINT
if (!debug_info.result) {
return debug_info;
}
// Points to an element in the _indices vector.
size_t index_pointer = 0;
// Points to the key index in the dict
Py_ssize_t dict_pointer = 0;
int num_guards_executed = 0;
// Use iter(dict.keys()) to iterate over the keys
py::object keys =
py::handle(obj).attr("keys")(); // py::object handles the references
PyObject* iterator = PyObject_GetIter(keys.ptr()); // new reference
PyObject* key = nullptr;
while (index_pointer < _indices.size() &&
(key = PyIter_Next(iterator))) { // new reference
if (dict_pointer == _indices[index_pointer]) {
KeyValueManager& key_value_manager = _key_value_managers[dict_pointer];
std::unique_ptr<GuardManager>& key_manager = key_value_manager.first;
if (key_manager) {
GuardDebugInfo debug_info = key_manager->check_verbose_nopybind(key);
num_guards_executed += debug_info.num_guards_executed;
if (!debug_info.result) {
Py_DECREF(key);
Py_DECREF(iterator);
return GuardDebugInfo(
false, debug_info.verbose_code_parts, num_guards_executed);
}
}
PyObject* value = PyDict_GetItem(obj, key); // borrowed ref
std::unique_ptr<GuardManager>& value_manager = key_value_manager.second;
if (value_manager) {
GuardDebugInfo debug_info =
value_manager->check_verbose_nopybind(value);
num_guards_executed += debug_info.num_guards_executed;
if (!debug_info.result) {
Py_DECREF(key);
Py_DECREF(iterator);
return GuardDebugInfo(
false, debug_info.verbose_code_parts, num_guards_executed);
}
}
index_pointer++;
}
Py_DECREF(key);
dict_pointer++;
}
Py_DECREF(iterator);
return GuardDebugInfo(true, num_guards_executed);
}
public: // cloning functions
DictSubclassGuardManager(
RootGuardManager* cloned_root,
std::string source,
Py_ssize_t size,
PyTypeObject* _expected_type,
bool is_exact_dict_type,
std::vector<Py_ssize_t> indices)
: DictGuardManager(
cloned_root,
std::move(source),
size,
_expected_type,
is_exact_dict_type,
std::move(indices)) {}
GuardManager* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_dict_guard_manager<DictSubclassGuardManager>(
cloned_root, clone_filter_fn);
}
};
// Free-function wrapper around GuardManager::clone. Returns whatever
// from->clone returns for the given cloned root and filter.
GuardManager* clone_guard_manager(
    GuardManager* from,
    RootGuardManager* cloned_root,
    const py::function& clone_filter_fn) {
  GuardManager* cloned = from->clone(cloned_root, clone_filter_fn);
  return cloned;
}
// Free-function wrapper that registers a relational guard with the given
// root so that the root resets the guard's state.
void add_relational_guard_resetter_to_cloned_root(
    RootGuardManager* root,
    std::shared_ptr<RelationalGuard> guard) {
  RootGuardManager& target_root = *root;
  target_root.add_relational_guard_resetter(std::move(guard));
}
// Factory for guard managers. Non-dict example values always get a plain
// GuardManager. For dict example values, guard_manager_enum (a member of
// torch._dynamo.guards.GuardManagerType) selects between GuardManager,
// DictGuardManager and DictSubclassGuardManager; any other enum value raises
// a Python TypeError.
std::unique_ptr<GuardManager> make_guard_manager(
    RootGuardManager* root,
    std::string source,
    py::handle example_value,
    py::handle guard_manager_enum) {
  // The enum class and its members are looked up once and cached.
#if IS_PYBIND_2_13_PLUS
  using fourobjects =
      std::tuple<py::object, py::object, py::object, py::object>;
  PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<fourobjects>
      storage;

  auto& [guard_manager_enum_class, base_guard_manager_enum, dict_guard_manager_enum, dict_subclass_guard_manager_enum] =
      storage
          .call_once_and_store_result([]() -> fourobjects {
            py::object guard_manager_enum_class =
                py::module_::import("torch._dynamo.guards")
                    .attr("GuardManagerType");
            return {
                guard_manager_enum_class,
                guard_manager_enum_class.attr("GUARD_MANAGER"),
                guard_manager_enum_class.attr("DICT_GUARD_MANAGER"),
                guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER")};
          })
          .get_stored();
#else
  static py::object guard_manager_enum_class =
      py::module_::import("torch._dynamo.guards").attr("GuardManagerType");
  static py::object base_guard_manager_enum =
      guard_manager_enum_class.attr("GUARD_MANAGER");
  static py::object dict_guard_manager_enum =
      guard_manager_enum_class.attr("DICT_GUARD_MANAGER");
  static py::object dict_subclass_guard_manager_enum =
      guard_manager_enum_class.attr("DICT_SUBCLASS_GUARD_MANAGER");
#endif

  // Anything that is not a dict is handled by the base GuardManager.
  if (!py::isinstance<py::dict>(example_value)) {
    return std::make_unique<GuardManager>(root, std::move(source));
  }

  // Why both DictGuardManager and DictSubclassGuardManager exist: dict guards
  // are inserted on the Python side (see guards.py) following the
  // list(d.keys()) ordering, so the C++ side must visit keys in that same
  // order. For standard dicts, .keys() internally uses PyDict_Next, so
  // DictGuardManager uses PyDict_Next directly to speed up key fetches.
  // Subclasses of dict may not follow that order: OrderedDict, for example,
  // overrides .keys() without changing the underlying data structure, which
  // can yield a different ordering than PyDict_Next.
  // DictSubclassGuardManager therefore calls the .keys() API itself; this is
  // less efficient but ensures correctness. Regular dicts are more common than
  // dict subclasses with an overridden keys method, so the common case stays
  // on the PyDict_Next path.
  if (guard_manager_enum.is(base_guard_manager_enum)) {
    // For dicts that don't need to guard on keys, the base GuardManager is
    // enough.
    return std::make_unique<GuardManager>(
        root, std::move(source), example_value);
  }
  if (guard_manager_enum.is(dict_guard_manager_enum)) {
    return std::make_unique<DictGuardManager>(
        root, std::move(source), example_value);
  }
  if (guard_manager_enum.is(dict_subclass_guard_manager_enum)) {
    return std::make_unique<DictSubclassGuardManager>(
        root, std::move(source), example_value);
  }
  throw py::type_error("Invalid guard manager enum");
}
class TORCH_FUNCTION_MODE_STACK : public LeafGuard {
public:
TORCH_FUNCTION_MODE_STACK(
const py::list& initial_stack,
py::object verbose_code_parts)
: LeafGuard(std::move(verbose_code_parts)), _ref_stack() {
Py_ssize_t len = PyList_Size(initial_stack.ptr());
for (Py_ssize_t idx = 0; idx < len; idx++) {
PyObject* mode = PyList_GetItem(initial_stack.ptr(), idx); // borrowed ref
auto type = Py_TYPE(mode);
this->_ref_stack.push_back(type);
}
}
bool check_nopybind(PyObject* value) override {
// Ignore value arg, only used to satisfy the interface
const size_t len = (size_t)at::impl::PythonTorchFunctionTLS::stack_len();
const size_t ref_stack_size = this->_ref_stack.size();
if (len != ref_stack_size) {
return false;
}
for (int64_t idx = 0; (size_t)idx < len; idx++) {
std::shared_ptr<c10::SafePyObject> mode =
at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
PyTypeObject* mode_type = Py_TYPE(mode->ptr(getPyInterpreter()));
if (mode_type != _ref_stack.at(idx)) {
return false;
}
}
return true;
}
private:
std::vector<PyTypeObject*> _ref_stack;
};
// Leaf guard that checks a value against a TensorCheck built from an example
// tensor.
class TENSOR_MATCH : public LeafGuard {
 public:
  // Builds the TensorCheck from `value`. Empty dynamic size/stride lists fall
  // back to the example tensor's own sizes/strides.
  //
  // Throws py::type_error when `value` is not a tensor. Throwing (instead of
  // setting a Python error and returning) guarantees that no TENSOR_MATCH
  // with a null _tensor_check is ever handed back to the caller.
  TENSOR_MATCH(
      RootGuardManager* root_guard_manager,
      py::object value,
      py::object dynamic_dims_sizes_py,
      py::object dynamic_dims_strides_py,
      py::object tensor_name,
      py::object verbose_code_parts)
      : LeafGuard(root_guard_manager, std::move(verbose_code_parts)),
        _tensor_name(py::cast<std::string>(std::move(tensor_name))) {
    root_guard_manager->set_init_local_state_flag();
    PyObject* item = value.ptr();
    if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
      throw py::type_error("expected Tensor()");
    }
    auto tensor = THPVariable_Unpack(item);

    std::vector<std::optional<c10::SymInt>> tensor_dims_size =
        pyListToVecOptInt(dynamic_dims_sizes_py.ptr());
    std::vector<std::optional<c10::SymInt>> tensor_dims_stride =
        pyListToVecOptInt(dynamic_dims_strides_py.ptr());

    tensor_dims_size = tensor_dims_size.empty()
        ? wrapIntegersInOptional(tensor.sym_sizes())
        : tensor_dims_size;
    tensor_dims_stride = tensor_dims_stride.empty()
        ? wrapIntegersInOptional(tensor.sym_strides())
        : tensor_dims_stride;
    LocalState state;
    _tensor_check = std::make_unique<TensorCheck>(
        state,
        Py_TYPE(item),
        std::move(tensor),
        std::move(tensor_dims_size),
        std::move(tensor_dims_stride));
  }

  // Fast check: the Python type must match exactly, then the TensorCheck runs
  // against the root manager's local state.
  bool check_nopybind(PyObject* value) override { // borrowed ref
    if (Py_TYPE(value) != _tensor_check->pytype) {
      return false;
    }
    return _tensor_check->check(
        _root_guard_manager->_local_state, THPVariable_Unpack(value));
  }

  // Verbose check: same checks as check_nopybind, with a failure reason.
  GuardDebugInfo check_verbose_nopybind(
      PyObject* value) override { // borrowed ref
    if (Py_TYPE(value) != _tensor_check->pytype) {
      std::stringstream fail_reason;
      fail_reason << "expected type of '" << _tensor_name
                  << "' to be a tensor type, ";
      // Py_TYPE yields a borrowed reference, so only the str() result needs
      // to be owned; py::object releases it on every path.
      py::object type_str = py::reinterpret_steal<py::object>(
          PyObject_Str(reinterpret_cast<PyObject*>(Py_TYPE(value))));
      const char* type_name =
          type_str ? PyUnicode_AsUTF8(type_str.ptr()) : nullptr;
      if (type_name == nullptr) {
        // str() or the UTF-8 conversion failed; drop the Python error so it
        // does not leak into the caller.
        PyErr_Clear();
        fail_reason << "but found a different type";
      } else {
        fail_reason << "but found " << type_name;
      }
      return GuardDebugInfo(false, fail_reason.str(), 0);
    }

    std::string fail_reason = _tensor_check->check_verbose(
        _root_guard_manager->_local_state,
        THPVariable_Unpack(value),
        _tensor_name);

    if (!fail_reason.empty()) {
      if (is_parameter(py::handle(value))) {
        fail_reason += ". Guard failed on a parameter, consider using ";
        fail_reason +=
            "torch._dynamo.config.force_parameter_static_shapes = False ";
        fail_reason += "to allow dynamism on parameters.";
      }
      return GuardDebugInfo(false, fail_reason, 0);
    }
    return GuardDebugInfo(true, 1);
  }

 private:
  // Name of the guarded tensor, used in failure messages.
  std::string _tensor_name;
  // Check built from the example tensor in the constructor.
  std::unique_ptr<TensorCheck> _tensor_check;
};
/**
* Represents __getattr__ acccessor.
*/
class GetAttrGuardAccessor : public GuardAccessor {
public:
GetAttrGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
name,
std::move(source),
example_value,
guard_manager_enum),
_attr_name(name.ptr()) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
if (x == nullptr) {
// Attribute absent, clear the exception and return false.
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
Py_DECREF(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = PyObject_GetAttr(obj, _attr_name); // new ref
if (x == nullptr) {
// Attribute absent, clear the exception and return false.
PyErr_Clear();
return GuardDebugInfo(
false, "getattr failed on source " + get_source(), 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
Py_DECREF(x);
return result;
}
std::string repr() const override {
// Helpful when priting GuardManager tree structure.
return "GetAttrGuardAccessor(" + py::str(_attr_name).cast<std::string>() +
")";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GetAttrGuardAccessor(GuardManager* guard_manager, GetAttrGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<GetAttrGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(GetAttrGuardAccessor* to) {
to->_attr_name = _attr_name;
}
private:
// no need of py::object here because the attr_name is already passed on to
// the base class as accessor_key which is a py::object.
PyObject* _attr_name;
};
/**
* Represents x.__dict__ acccessor.
*/
class GetGenericDictGuardAccessor : public GuardAccessor {
public:
GetGenericDictGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref
if (x == nullptr) {
// Attribute absent, clear the exception and return false.
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
Py_DECREF(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = PyObject_GenericGetDict(obj, nullptr); // new ref
if (x == nullptr) {
// Attribute absent, clear the exception and return false.
PyErr_Clear();
return GuardDebugInfo(
false, "getattr failed on source " + get_source(), 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
Py_DECREF(x);
return result;
}
std::string repr() const override {
// Helpful when priting GuardManager tree structure.
return "GetGenericDictGuardAccessor";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GetGenericDictGuardAccessor(
GuardManager* guard_manager,
GetGenericDictGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<GetGenericDictGuardAccessor>(
cloned_root, clone_filter_fn);
}
};
/**
* Represents __getitem__ acccessor.
*/
class GetItemGuardAccessor : public GuardAccessor {
public:
GetItemGuardAccessor(
RootGuardManager* root,
py::object name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
name,
std::move(source),
example_value,
guard_manager_enum),
_attr_name(name.ptr()) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = PyObject_GetItem(obj, _attr_name); // new ref
if (x == nullptr) {
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
Py_DECREF(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = PyObject_GetItem(obj, _attr_name); // new ref
if (x == nullptr) {
PyErr_Clear();
return GuardDebugInfo(
false, std::string("KeyError on ") + get_source(), 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
Py_DECREF(x);
return result;
}
std::string repr() const override {
return "GetItemGuardAccessor(" + py::str(_attr_name).cast<std::string>() +
")";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GetItemGuardAccessor(GuardManager* guard_manager, GetItemGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<GetItemGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(GetItemGuardAccessor* to) {
to->_attr_name = _attr_name;
}
private:
// no need of py::object here because the attr_name is already passed on to
// the base class as accessor_key which is a py::object.
PyObject* _attr_name;
};
/**
* Represents dict[name] acccessor. This is ONLY used for f_locals because its a
* dict, and DictGuardManager does not support sorting. We differentiate it from
* GetItemGuardAccessor because PyDict_GetItem should be fasten the
* PyObject_GetItem.
*/
class DictGetItemGuardAccessor : public GuardAccessor {
public:
DictGetItemGuardAccessor(
RootGuardManager* root,
py::object key,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
key,
std::move(source),
example_value,
guard_manager_enum),
_key(key.ptr()),
_is_immutable_object(is_immutable_object(example_value)) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
if (matches_dict_tag && _is_immutable_object) {
// immutable object and dict tag matches, we can skip the guard subtree.
return true;
}
PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref
if (x == nullptr) {
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = PyDict_GetItem(obj, _key); // borrowed ref
if (x == nullptr) {
PyErr_Clear();
return GuardDebugInfo(
false, std::string("KeyError on ") + get_source(), 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
return result;
}
std::string repr() const override {
return "DictGetItemGuardAccessor(" + py::repr(_key).cast<std::string>() +
")";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
DictGetItemGuardAccessor(
GuardManager* guard_manager,
DictGetItemGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<DictGetItemGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(DictGetItemGuardAccessor* to) {
to->_key = _key;
to->_is_immutable_object = _is_immutable_object;
}
private:
PyObject* _key;
// If immutable object and dict tag matches, we can skip the guard subtree and
// return true.
bool _is_immutable_object;
};
/**
* Represents list[index] accessor. It is faster than generic
* GetItemGuardAccessor.
*/
class ListGetItemGuardAccessor : public GuardAccessor {
public:
ListGetItemGuardAccessor(
RootGuardManager* root,
const py::object& index,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
index,
std::move(source),
example_value,
guard_manager_enum),
_index(py::cast<Py_ssize_t>(index)) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = PyList_GetItem(obj, _index); // borrowed ref
if (x == nullptr) {
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = PyList_GetItem(obj, _index); // borrowed ref
if (x == nullptr) {
PyErr_Clear();
return GuardDebugInfo(
false, std::string("IndexError on ") + get_source(), 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
return result;
}
std::string repr() const override {
return "ListGetItemGuardAccessor(" + std::to_string(_index) + ")";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
ListGetItemGuardAccessor(
GuardManager* guard_manager,
ListGetItemGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<ListGetItemGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(ListGetItemGuardAccessor* to) {
to->_index = _index;
}
private:
Py_ssize_t _index;
};
/**
* Represents tuple[index] accessor. It is faster than generic
* GetItemGuardAccessor.
*/
class TupleGetItemGuardAccessor : public GuardAccessor {
public:
TupleGetItemGuardAccessor(
RootGuardManager* root,
const py::object& index,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
index,
std::move(source),
example_value,
guard_manager_enum),
_index(py::cast<Py_ssize_t>(index)) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = PyTuple_GetItem(obj, _index); // borrowed ref
if (x == nullptr) {
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = PyTuple_GetItem(obj, _index); // borrowed ref
if (x == nullptr) {
PyErr_Clear();
return GuardDebugInfo(
false, std::string("IndexError on ") + get_source(), 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
return result;
}
std::string repr() const override {
return "TupleGetItemGuardAccessor(" + std::to_string(_index) + ")";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TupleGetItemGuardAccessor(
GuardManager* guard_manager,
TupleGetItemGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<TupleGetItemGuardAccessor>(
cloned_root, clone_filter_fn);
}
void clone_visitor(TupleGetItemGuardAccessor* to) {
to->_index = _index;
}
private:
Py_ssize_t _index;
};
/**
* Represents tensor.grad acccessor.
*/
class GradGuardAccessor : public GuardAccessor {
public:
GradGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
// check that its a tensor
if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
return false;
}
PyObject* grad =
THPVariable_Wrap(THPVariable_Unpack(obj).grad()); // New reference
bool result = _guard_manager->check_nopybind(grad);
// For undefined tensor, THPVariable_Wrap returns Py_RETURN_NONE. So, no
// need of Py_XDECREF.
Py_DECREF(grad);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
// check that its a tensor
if (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)) {
return GuardDebugInfo(
false, "not a tensor - grad field is accessed " + get_source(), 0);
}
PyObject* grad =
THPVariable_Wrap(THPVariable_Unpack(obj).grad()); // New reference
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(grad);
// For undefined tensor, THPVariable_Wrap returns Py_RETURN_NONE. So, no
// need of Py_XDECREF.
Py_DECREF(grad);
return result;
}
std::string repr() const override {
// Helpful when priting GuardManager tree structure.
return "GradGuardAccessor(grad)";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GradGuardAccessor(GuardManager* guard_manager, GradGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<GradGuardAccessor>(cloned_root, clone_filter_fn);
}
};
/**
* Represents func.__defaults__ accessor.
*/
class FuncDefaultsGuardAccessor : public GuardAccessor {
public:
FuncDefaultsGuardAccessor(
RootGuardManager* root,
py::object name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* func = obj;
if (PyMethod_Check(obj)) {
func = PyMethod_GET_FUNCTION(obj); // borrowed ref
} else if (PyInstanceMethod_Check(obj)) {
func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
}
PyObject* x = PyFunction_GetDefaults(func); // borrowed ref
if (x == nullptr) {
PyErr_Clear();
return false;
}
return _guard_manager->check_nopybind(x);
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* func = obj;
if (PyMethod_Check(obj)) {
func = PyMethod_GET_FUNCTION(obj); // borrowed ref
} else if (PyInstanceMethod_Check(obj)) {
func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
}
PyObject* x = PyFunction_GetDefaults(func);
if (x == nullptr) {
PyErr_Clear();
return GuardDebugInfo(
false,
std::string(repr() + ": Not a function on ") + get_source(),
0);
}
return _guard_manager->check_verbose_nopybind(x);
}
std::string repr() const override {
return "FuncDefaultsGuardAccessor";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FuncDefaultsGuardAccessor(
GuardManager* guard_manager,
FuncDefaultsGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<FuncDefaultsGuardAccessor>(
cloned_root, clone_filter_fn);
}
};
/**
* Represents func.__kwdefaults__ accessor.
*/
class FuncKwDefaultsGuardAccessor : public GuardAccessor {
public:
FuncKwDefaultsGuardAccessor(
RootGuardManager* root,
py::object name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* func = obj;
if (PyMethod_Check(obj)) {
func = PyMethod_GET_FUNCTION(obj); // borrowed ref
} else if (PyInstanceMethod_Check(obj)) {
func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
}
PyObject* x = PyFunction_GetKwDefaults(func); // borrowed ref
if (x == nullptr) {
PyErr_Clear();
return false;
}
return _guard_manager->check_nopybind(x);
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* func = obj;
if (PyMethod_Check(obj)) {
func = PyMethod_GET_FUNCTION(obj); // borrowed ref
} else if (PyInstanceMethod_Check(obj)) {
func = PyInstanceMethod_GET_FUNCTION(obj); // borrowed ref
}
PyObject* x = PyFunction_GetKwDefaults(func);
if (x == nullptr) {
PyErr_Clear();
return GuardDebugInfo(
false,
std::string(repr() + ": Not a function on ") + get_source(),
0);
}
return _guard_manager->check_verbose_nopybind(x);
}
std::string repr() const override {
return "FuncKwDefaultsGuardAccessor";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FuncKwDefaultsGuardAccessor(
GuardManager* guard_manager,
FuncKwDefaultsGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<FuncKwDefaultsGuardAccessor>(
cloned_root, clone_filter_fn);
}
};
/**
* Represents f_globals acccessor. This sits as a child accessor of the
* RootGuardManager.
*/
class GlobalsGuardAccessor : public GuardAccessor {
public:
GlobalsGuardAccessor(
RootGuardManager* root,
py::dict globals_dict,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
globals_dict,
std::move(source),
example_value,
guard_manager_enum),
_globals_dict(globals_dict.ptr()) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
// Ignore the obj arg. This is required to satisfy the function signature.
// Just pass on the globals dict to the child manager.
return _guard_manager->check_nopybind(_globals_dict);
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
// Ignore the obj arg. This is required to satisfy the function signature.
// Just pass on the globals dict to the child manager.
return _guard_manager->check_verbose_nopybind(_globals_dict);
}
std::string repr() const override {
return "GlobalsGuardAccessor";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GlobalsGuardAccessor(GuardManager* guard_manager, GlobalsGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<GlobalsGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(GlobalsGuardAccessor* to) {
to->_globals_dict = _globals_dict;
}
private:
// no need of py::object here because the globals_dict is already passed on to
// the base class as accessor_key which is a py::object.
PyObject* _globals_dict;
};
/**
* Represent type(...) accessor.
*/
class TypeGuardAccessor : public GuardAccessor {
public:
// name = __type_accessor__, a unique string used as attribute name.
TypeGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = (PyObject*)Py_TYPE(obj); // borrowed ref
return _guard_manager->check_nopybind(x);
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = (PyObject*)Py_TYPE(obj); // borrowed ref
return _guard_manager->check_verbose_nopybind(x);
}
std::string repr() const override {
return "TypeGuardAccessor";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TypeGuardAccessor(GuardManager* guard_manager, TypeGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<TypeGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(TypeGuardAccessor* to) {}
};
/**
* Getitem tuple_iterator accessor.
*/
class TupleIteratorGetItemAccessor : public GuardAccessor {
public:
TupleIteratorGetItemAccessor(
RootGuardManager* root,
py::object index,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
index,
std::move(source),
example_value,
guard_manager_enum),
_index(py::cast<Py_ssize_t>(std::move(index))) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
_PyTupleIterObject* it = (_PyTupleIterObject*)obj;
PyObject* x =
PyTuple_GET_ITEM(it->it_seq, it->it_index + _index); // borrowed ref
if (x == nullptr) {
// Out of range.
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
_PyTupleIterObject* it = (_PyTupleIterObject*)obj;
PyObject* x =
PyTuple_GET_ITEM(it->it_seq, it->it_index + _index); // borrowed ref
if (x == nullptr) {
// Out of range.
PyErr_Clear();
return GuardDebugInfo(false, std::string("IndexError ") + repr(), 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
return result;
}
std::string repr() const override {
return "TupleIteratorGetItemAccessor(" + std::to_string(_index) + ")";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
TupleIteratorGetItemAccessor(
GuardManager* guard_manager,
TupleIteratorGetItemAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<TupleIteratorGetItemAccessor>(
cloned_root, clone_filter_fn);
}
void clone_visitor(TupleIteratorGetItemAccessor* to) {
to->_index = _index;
}
private:
Py_ssize_t _index;
};
/**
* GlobalWeakRef accessor. Dynamo can insert a weakref object into the frame
* globals. This accessor reads the globals and then calls the weakref object
* to get the underlying object. This is a child of GlobalsGuardAccessor.
* Therefore, we will get the globals dict while caling check_nopybind.
*/
class GlobalWeakRefGuardAccessor : public GuardAccessor {
public:
GlobalWeakRefGuardAccessor(
RootGuardManager* root,
py::object global_name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
global_name,
std::move(source),
example_value,
guard_manager_enum),
_global_name(global_name.ptr()) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
// obj is globals dict because GlobalWeakRefGuardAccessor has to be a
// child of GlobalsGuardAccessor.
PyObject* weakref = PyDict_GetItem(obj, _global_name); // borrowed ref
if (weakref == nullptr) {
// The weakref is not in the globals dict.
PyErr_Clear();
return false;
}
if (!PyWeakref_Check(weakref)) {
return false;
}
PyObject* x = nullptr;
if (PyWeakref_GetRef(weakref, &x) == -1) { // strong reference
// error when attempting to call ref
PyErr_Clear();
return false;
}
if (x == nullptr) {
// weakref is dead
x = Py_NewRef(Py_None);
}
bool result = _guard_manager->check_nopybind(x);
Py_DECREF(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
// obj is globals dict because GlobalWeakRefGuardAccessor has to be a
// child of GlobalsGuardAccessor.
PyObject* weakref = PyDict_GetItem(obj, _global_name); // borrowed ref
if (weakref == nullptr) {
// The weakref is not in the globals dict.
PyErr_Clear();
return GuardDebugInfo(
false, std::string("KeyError on ") + get_source(), 0);
}
if (!PyWeakref_Check(weakref)) {
return GuardDebugInfo(
false, std::string("Not a weakref ") + get_source(), 0);
}
PyObject* x = nullptr;
if (PyWeakref_GetRef(weakref, &x) == -1) { // strong reference
// error when attempting to call ref
PyErr_Clear();
return GuardDebugInfo(
false, std::string("Weakref_GetRef failed ") + get_source(), 0);
}
if (x == nullptr) {
// weakref is dead
x = Py_NewRef(Py_None);
}
auto result = _guard_manager->check_verbose_nopybind(x);
Py_DECREF(x);
return result;
}
std::string repr() const override {
return "GlobalWeakRefGuardAccessor(" +
py::str(_global_name).cast<std::string>() + ")";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
GlobalWeakRefGuardAccessor(
GuardManager* guard_manager,
GlobalWeakRefGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<GlobalWeakRefGuardAccessor>(
cloned_root, clone_filter_fn);
}
void clone_visitor(GlobalWeakRefGuardAccessor* to) {
to->_global_name = _global_name;
}
private:
PyObject* _global_name;
};
/**
* Implements weakref call - x_weak()
*/
class WeakRefCallGuardAccessor : public GuardAccessor {
public:
WeakRefCallGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
if (!PyWeakref_Check(obj)) {
return false;
}
PyObject* x = nullptr;
if (PyWeakref_GetRef(obj, &x) == -1) { // strong reference
// error when attempting to call ref
PyErr_Clear();
return false;
}
if (x == nullptr) {
// weakref is dead
x = Py_NewRef(Py_None);
}
bool result = _guard_manager->check_nopybind(x);
Py_DECREF(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
if (!PyWeakref_Check(obj)) {
return GuardDebugInfo(
false, std::string("Not a weakref obj ") + get_source(), 0);
}
PyObject* x = nullptr;
if (PyWeakref_GetRef(obj, &x) == -1) { // strong reference
// error when attempting to call ref
PyErr_Clear();
return GuardDebugInfo(
false, std::string("Weakref_GetRef failed ") + get_source(), 0);
}
if (x == nullptr) {
// weakref is dead
x = Py_NewRef(Py_None);
}
auto result = _guard_manager->check_verbose_nopybind(x);
Py_DECREF(x);
return result;
}
std::string repr() const override {
return "WeakRefCallGuardAccessor()";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
WeakRefCallGuardAccessor(
GuardManager* guard_manager,
WeakRefCallGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<WeakRefCallGuardAccessor>(cloned_root, clone_filter_fn);
}
void clone_visitor(WeakRefCallGuardAccessor* to) {}
};
/**
* Implements function call no args - e.g, torch.cuda.current_device()
*/
class CallFunctionNoArgsGuardAccessor : public GuardAccessor {
public:
CallFunctionNoArgsGuardAccessor(
RootGuardManager* root,
py::str name,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
std::move(name),
std::move(source),
example_value,
guard_manager_enum) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
if (!PyCallable_Check(obj)) {
return false;
}
PyObject* x = PyObject_CallNoArgs(obj);
if (x == nullptr) {
// Call failed, clear the exception and return false.
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
Py_DECREF(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
if (!PyCallable_Check(obj)) {
return GuardDebugInfo(
false, std::string("Not a callable obj ") + get_source(), 0);
}
PyObject* x = PyObject_CallNoArgs(obj);
if (x == nullptr) {
// Call failed, clear the exception and return debug info.
std::string exc_message = get_exception_message();
PyErr_Clear();
return GuardDebugInfo(false, exc_message, 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
Py_DECREF(x);
return result;
}
std::string repr() const override {
return "CallFunctionNoArgsGuardAccessor()";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
CallFunctionNoArgsGuardAccessor(
GuardManager* guard_manager,
CallFunctionNoArgsGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<CallFunctionNoArgsGuardAccessor>(
cloned_root, clone_filter_fn);
}
void clone_visitor(CallFunctionNoArgsGuardAccessor* to) {}
};
/**
* Similar to PythonLambdaLeafGuard, this class is a way to allow developers to
* supply accessor as a python function. This is useful for from_numpy source.
*/
class PythonLambdaGuardAccessor : public GuardAccessor {
public:
PythonLambdaGuardAccessor(
RootGuardManager* root,
py::function accessor_fn,
std::string source,
py::handle example_value,
py::handle guard_manager_enum)
: GuardAccessor(
root,
accessor_fn,
std::move(source),
example_value,
guard_manager_enum),
_accessor_fn(std::move(accessor_fn)) {}
// NB: Intentional duplication between check_nopybind and
// check_verbose_nopybind.
bool check_nopybind(PyObject* obj, bool matches_dict_tag = false)
override { // borrowed ref
PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
if (x == nullptr) {
// The accessor function failed.
PyErr_Clear();
return false;
}
bool result = _guard_manager->check_nopybind(x);
Py_DECREF(x);
return result;
}
GuardDebugInfo check_verbose_nopybind(
PyObject* obj) override { // borrowed ref
PyObject* x = PyObject_CallOneArg(_accessor_fn.ptr(), obj); // new ref
if (x == nullptr) {
// The accessor function failed.
std::string exc_message = get_exception_message();
PyErr_Clear();
return GuardDebugInfo(false, exc_message, 0);
}
GuardDebugInfo result = _guard_manager->check_verbose_nopybind(x);
Py_DECREF(x);
return result;
}
std::string repr() const override {
return "PythonLambdaGuardAccessor";
}
public: // cloning functions
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
PythonLambdaGuardAccessor(
GuardManager* guard_manager,
PythonLambdaGuardAccessor* from)
: GuardAccessor(guard_manager, from) {
from->clone_visitor(this);
}
GuardAccessor* clone(
RootGuardManager* cloned_root,
const py::function& clone_filter_fn) override {
return clone_common<PythonLambdaGuardAccessor>(
cloned_root, clone_filter_fn);
}
void clone_visitor(PythonLambdaGuardAccessor* to) {
to->_accessor_fn = _accessor_fn;
}
private:
py::object _accessor_fn;
};
// Installs an "X is Y" guard on two guard managers. This is an example of a
// relational guard: a single guard object shared by both managers.
void install_object_aliasing_guard(
    GuardManager* x,
    GuardManager* y,
    py::object verbose_code_parts) {
  std::shared_ptr<RelationalGuard> aliasing_guard(
      std::make_shared<OBJECT_ALIASING>(std::move(verbose_code_parts)));

  // Register the resetter on the root guard manager, so that it can reset
  // the newly added relational guard when the guard eval fails.
  RootGuardManager* root = x->get_root();
  root->add_relational_guard_resetter(aliasing_guard);

  // OBJECT_ALIASING is a permitted guard even when the manager is a
  // DictGuardManager.
  for (GuardManager* manager : {x, y}) {
    manager->add_permitted_leaf_guard(aliasing_guard);
  }
}
// Installs a guard checking that none of the tensors alias each other. This
// is an example of a relational guard: one guard object is shared between all
// of the given guard managers.
void install_no_tensor_aliasing_guard(
    const py::list& guard_managers,
    const py::list& tensor_names,
    py::object verbose_code_parts) {
  std::shared_ptr<RelationalGuard> no_alias_guard(
      std::make_shared<NO_TENSOR_ALIASING>(
          tensor_names, std::move(verbose_code_parts)));

  // Register the resetter on the root guard manager, so that it can reset
  // the newly added relational guard when the guard eval fails.
  GuardManager* first_manager = py::cast<GuardManager*>(guard_managers[0]);
  first_manager->get_root()->add_relational_guard_resetter(no_alias_guard);

  for (const auto& manager_handle : guard_managers) {
    GuardManager* manager = py::cast<GuardManager*>(manager_handle);
    manager->add_leaf_guard(no_alias_guard);
  }
}
// Creates one STORAGE_OVERLAPPING guard (overlapping or non-overlapping,
// depending on `overlapping`) backed by the shared `checker`, and attaches it
// to every manager in `guard_managers`.
void install_storage_overlapping_guard_with_checker(
    std::shared_ptr<StorageOverlapChecker> checker,
    const py::list& guard_managers,
    py::object verbose_code_parts,
    bool overlapping) {
  // With no GuardManagers there is nothing to attach a STORAGE_OVERLAPPING
  // guard to, so none is created.
  if (guard_managers.size() == 0) {
    return;
  }

  std::shared_ptr<RelationalGuard> overlap_guard(
      std::make_shared<STORAGE_OVERLAPPING>(
          overlapping, checker, verbose_code_parts));

  // The root manager keeps the guard registered as a relational guard
  // resetter.
  GuardManager* first_manager = py::cast<GuardManager*>(guard_managers[0]);
  first_manager->get_root()->add_relational_guard_resetter(overlap_guard);

  for (const auto& manager_handle : guard_managers) {
    GuardManager* manager = py::cast<GuardManager*>(manager_handle);
    manager->add_leaf_guard(overlap_guard);
  }
}
// Installs the pair of STORAGE_OVERLAPPING guards (overlapping and
// non-overlapping) over the two lists of guard managers.
void install_storage_overlapping_guard(
    const py::list& overlapping_guard_managers,
    const py::list& non_overlapping_guard_managers,
    py::object verbose_code_parts) {
  // A single StorageOverlapChecker is shared by both guards created below.
  const auto num_overlapping = overlapping_guard_managers.size();
  const auto num_non_overlapping = non_overlapping_guard_managers.size();
  auto shared_checker = std::make_shared<StorageOverlapChecker>(
      num_overlapping, num_non_overlapping);

  // Guard for the possibly overlapping storages.
  install_storage_overlapping_guard_with_checker(
      shared_checker,
      overlapping_guard_managers,
      verbose_code_parts,
      /* overlapping= */ true);

  // Guard for the non-overlapping storages.
  install_storage_overlapping_guard_with_checker(
      shared_checker,
      non_overlapping_guard_managers,
      verbose_code_parts,
      /* overlapping= */ false);
}
// Measures the average cost of one RootGuardManager::check_nopybind call.
//
// Args:
//   root: the root guard manager to profile.
//   f_locals: the Python object passed to every check_nopybind call.
//
// Returns:
//   Average wall time per check, in microseconds.
//
// The guard result itself is ignored; only the timing matters. After a short
// warmup, checks are run back to back until the profiling window has elapsed.
double profile_guard_manager(RootGuardManager* root, py::object f_locals) {
  // steady_clock is monotonic, so the measured interval cannot be distorted
  // by wall-clock adjustments (high_resolution_clock may alias system_clock).
  using profile_clock = std::chrono::steady_clock;
  constexpr int kWarmupIterations = 10;
  // Length of the profiling window, in seconds.
  constexpr std::chrono::duration<double> kProfileDuration{1.0};

  PyObject* locals = f_locals.ptr();

  // Warmup
  for (int i = 0; i < kWarmupIterations; i++) {
    root->check_nopybind(locals);
  }

  long long count = 0;
  std::chrono::duration<double> elapsed{0.0};
  const auto start = profile_clock::now();
  // Always runs at least once, so `count` is never zero at the division.
  do {
    root->check_nopybind(locals);
    count++;
    elapsed = profile_clock::now() - start;
  } while (elapsed < kProfileDuration);

  // seconds -> microseconds, averaged over the number of checks executed.
  return (elapsed.count() * 1e6) / static_cast<double>(count);
}
} // namespace
// Returns the data pointer of the tensor wrapped by `obj`.
// Throws std::runtime_error when `obj` is null or is not a tensor object.
static void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj) {
  if (C10_UNLIKELY(
          obj == nullptr ||
          !(THPVariable_CheckExact(obj) || THPVariable_Check(obj)))) {
    throw std::runtime_error(
        "_torchinductor_pyobject_tensor_data_ptr: non-tensor input");
  }
  return THPVariable_Unpack(obj).data_ptr();
}
// Converts a Python-side root guard manager into an opaque pointer.
// Python `None` stands for an invalidated guard and maps to nullptr.
void* convert_to_root_guard_manager(py::object root) {
  if (root.is_none()) {
    return nullptr;
  }
  RootGuardManager* manager = std::move(root).cast<RootGuardManager*>();
  return static_cast<void*>(manager);
}
bool run_root_guard_manager(void* root, PyObject* f_locals) {
// for invalidated guards, return false
if (root == nullptr) {
return false;
}
return ((RootGuardManager*)root)->check_nopybind(f_locals);
}
// Creates the guards extension module described by `_module` and registers
// everything it exposes:
//   * the TensorGuards and GlobalStateGuard C types,
//   * the address of _torchinductor_pyobject_tensor_data_ptr as an integer,
//   * pybind11 bindings for GuardDebugInfo, the leaf guards, the guard
//     accessors and the GuardManager hierarchy,
//   * module-level helper functions,
//   * on Python 3.12+, the dict watcher used by dict-version guards.
//
// Returns the new module object, or nullptr (with a Python error set by the
// failing C API call) if type or module setup fails. Throws
// std::runtime_error if the dict watcher cannot be installed.
PyObject* torch_c_dynamo_guards_init() {
  // initialize TensorGuardsType
  TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
  TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
  TensorGuardsType.tp_itemsize = 0;
  TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc;
  TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT;
  TensorGuardsType.tp_doc = "Check properties of a torch.Tensor";
  TensorGuardsType.tp_methods = TensorGuards_methods;
  TensorGuardsType.tp_init = (initproc)TensorGuards_init;
  TensorGuardsType.tp_new = TensorGuards_new;

  if (PyType_Ready(&TensorGuardsType) < 0)
    return nullptr;

  // initialize GlobalStateGuardType
  GlobalStateGuardType.tp_name = "torch._C._dynamo.guards.GlobalStateGuard";
  GlobalStateGuardType.tp_basicsize = sizeof(GlobalStateGuard);
  GlobalStateGuardType.tp_itemsize = 0;
  GlobalStateGuardType.tp_flags = Py_TPFLAGS_DEFAULT;
  GlobalStateGuardType.tp_doc = "Guard on PyTorch global flags such as no_grad";
  GlobalStateGuardType.tp_methods = GlobalStateGuard_methods;
  GlobalStateGuardType.tp_init = (initproc)GlobalStateGuard_init;
  GlobalStateGuardType.tp_new = PyType_GenericNew;

  if (PyType_Ready(&GlobalStateGuardType) < 0)
    return nullptr;

  auto m = PyModule_Create(&_module);
  if (m == nullptr)
    return nullptr;

#ifdef Py_GIL_DISABLED
  PyUnstable_Module_SetGIL(m, Py_MOD_GIL_NOT_USED);
#endif

  // PyModule_AddObject steals a reference only on success, so every failure
  // path below releases the reference it created as well as the module.
  Py_INCREF(&TensorGuardsType);
  if (PyModule_AddObject(m, "TensorGuards", (PyObject*)&TensorGuardsType) < 0) {
    Py_DECREF(&TensorGuardsType);
    Py_DECREF(m);
    return nullptr;
  }

  Py_INCREF(&GlobalStateGuardType);
  if (PyModule_AddObject(
          m, "GlobalStateGuard", (PyObject*)&GlobalStateGuardType) < 0) {
    Py_DECREF(&GlobalStateGuardType);
    Py_DECREF(m);
    return nullptr;
  }

  // We expose the address of _torchinductor_pyobject_tensor_data_ptr in order
  // to allow manual linking in our generated TorchInductor Python bindings.
  // While regular linking works in most cases, it does not work properly in
  // fbcode due to janky build setup there.
  PyObject* data_ptr_fn_addr = PyLong_FromVoidPtr(
      reinterpret_cast<void*>(&_torchinductor_pyobject_tensor_data_ptr));
  if (data_ptr_fn_addr == nullptr) {
    Py_DECREF(m);
    return nullptr;
  }
  if (PyModule_AddObject(
          m, "_torchinductor_pyobject_tensor_data_ptr", data_ptr_fn_addr) <
      0) {
    Py_DECREF(data_ptr_fn_addr);
    Py_DECREF(m);
    return nullptr;
  }

  auto py_m = py::handle(m).cast<py::module>();
  py::class_<GuardDebugInfo, std::unique_ptr<GuardDebugInfo>>(
      py_m, "GuardDebugInfo")
      .def(py::init<bool, py::list, int>())
      .def("__str__", &GuardDebugInfo::to_string)
      .def_readonly("result", &GuardDebugInfo::result)
      .def_readonly("verbose_code_parts", &GuardDebugInfo::verbose_code_parts)
      .def_readonly(
          "num_guards_executed", &GuardDebugInfo::num_guards_executed);

  // Leaf Guards
  py::class_<LeafGuard, std::shared_ptr<LeafGuard>>(py_m, "LeafGuard")
      .def("verbose_code_parts", &LeafGuard::verbose_code_parts);
  py::class_<LAMBDA_GUARD, LeafGuard, std::shared_ptr<LAMBDA_GUARD>>(
      py_m, "LAMBDA_GUARD")
      .def(py::init<py::function, py::list>())
      .def("__call__", &LAMBDA_GUARD::check);
  py::class_<TYPE_MATCH, LeafGuard, std::shared_ptr<TYPE_MATCH>>(
      py_m, "TYPE_MATCH")
      .def(py::init<py::object, py::list>())
      .def("__call__", &TYPE_MATCH::check);
  py::class_<ID_MATCH, LeafGuard, std::shared_ptr<ID_MATCH>>(py_m, "ID_MATCH")
      .def(py::init<py::object, py::list>())
      .def("__call__", &ID_MATCH::check);
  py::class_<EQUALS_MATCH, LeafGuard, std::shared_ptr<EQUALS_MATCH>>(
      py_m, "EQUALS_MATCH")
      .def(py::init<py::object, py::list>())
      .def("__call__", &EQUALS_MATCH::check);
  py::class_<LENGTH_CHECK, LeafGuard, std::shared_ptr<LENGTH_CHECK>>(
      py_m, "LENGTH_CHECK")
      .def(py::init<py::object, py::list>())
      .def("__call__", &LENGTH_CHECK::check);
  py::class_<DICT_LENGTH, LeafGuard, std::shared_ptr<DICT_LENGTH>>(
      py_m, "DICT_LENGTH")
      .def(py::init<py::object, py::list>())
      .def("__call__", &DICT_LENGTH::check);
  py::class_<DEFAULT_DEVICE, LeafGuard, std::shared_ptr<DEFAULT_DEVICE>>(
      py_m, "DEFAULT_DEVICE")
      .def(py::init<py::list>())
      .def("__call__", &DEFAULT_DEVICE::check);
  py::class_<NOT_NONE, LeafGuard, std::shared_ptr<NOT_NONE>>(py_m, "NOT_NONE")
      .def(py::init<py::list>())
      .def("__call__", &NOT_NONE::check);
  py::class_<
      TUPLE_ITERATOR_LEN,
      LeafGuard,
      std::shared_ptr<TUPLE_ITERATOR_LEN>>(py_m, "TUPLE_ITERATOR_LEN")
      .def(py::init<py::object, py::object, py::list>())
      .def("__call__", &TUPLE_ITERATOR_LEN::check);
  py::class_<
      RANGE_ITERATOR_MATCH,
      LeafGuard,
      std::shared_ptr<RANGE_ITERATOR_MATCH>>(py_m, "RANGE_ITERATOR_MATCH")
      .def(py::init<py::object, py::object, py::object, py::object, py::list>())
      .def("__call__", &RANGE_ITERATOR_MATCH::check);
  py::class_<GLOBAL_STATE, LeafGuard, std::shared_ptr<GLOBAL_STATE>>(
      py_m, "GLOBAL_STATE")
      .def(py::init<py::list>())
      .def("check_verbose", &GLOBAL_STATE::check_verbose)
      .def("__call__", &GLOBAL_STATE::check);
  py::class_<
      TORCH_FUNCTION_MODE_STACK,
      LeafGuard,
      std::shared_ptr<TORCH_FUNCTION_MODE_STACK>>(
      py_m, "TORCH_FUNCTION_MODE_STACK")
      .def(py::init<py::list, py::list>())
      .def("__call__", &TORCH_FUNCTION_MODE_STACK::check);
  py::class_<DATA_PTR_MATCH, LeafGuard, std::shared_ptr<DATA_PTR_MATCH>>(
      py_m, "DATA_PTR_MATCH")
      .def(py::init<py::object, py::list>())
      .def("__call__", &DATA_PTR_MATCH::check);
  py::class_<NO_HASATTR, LeafGuard, std::shared_ptr<NO_HASATTR>>(
      py_m, "NO_HASATTR")
      .def(py::init<py::object, py::list>())
      .def("__call__", &NO_HASATTR::check);
  py::class_<DICT_CONTAINS, LeafGuard, std::shared_ptr<DICT_CONTAINS>>(
      py_m, "DICT_CONTAINS")
      .def(py::init<bool, py::object, py::list>())
      .def("__call__", &DICT_CONTAINS::check);
  py::class_<DYNAMIC_INDICES, LeafGuard, std::shared_ptr<DYNAMIC_INDICES>>(
      py_m, "DYNAMIC_INDICES")
      .def(py::init<py::set, py::list>())
      .def("__call__", &DYNAMIC_INDICES::check);
  py::class_<DICT_VERSION, LeafGuard, std::shared_ptr<DICT_VERSION>>(
      py_m, "DICT_VERSION")
      .def(py::init<py::object, py::list>())
      .def("__call__", &DICT_VERSION::check);
  py::class_<TENSOR_MATCH, LeafGuard, std::shared_ptr<TENSOR_MATCH>>(
      py_m, "TENSOR_MATCH")
      .def(py::init<
           RootGuardManager*,
           py::object,
           py::object,
           py::object,
           py::str,
           py::list>())
      .def("__call__", &TENSOR_MATCH::check);
  // The relational guards below are registered without constructors or
  // methods; they are created through the install_* module functions.
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<OBJECT_ALIASING, LeafGuard, std::shared_ptr<OBJECT_ALIASING>>(
      py_m, "OBJECT_ALIASING");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      NO_TENSOR_ALIASING,
      LeafGuard,
      std::shared_ptr<NO_TENSOR_ALIASING>>(py_m, "NO_TENSOR_ALIASING");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      STORAGE_OVERLAPPING,
      LeafGuard,
      std::shared_ptr<STORAGE_OVERLAPPING>>(py_m, "STORAGE_OVERLAPPING");

  // Guard Accessors - These are present so that we can iterate over the
  // GuardManager hierarchy. We intentionally do not provide even an init
  // function on these, because these should be constructed from within C++.
  py::class_<GuardAccessor, std::unique_ptr<GuardAccessor>>(
      py_m, "GuardAccessor")
      .def("repr", &GuardAccessor::repr);
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      GetAttrGuardAccessor,
      GuardAccessor,
      std::unique_ptr<GetAttrGuardAccessor>>(py_m, "GetAttrGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      GetGenericDictGuardAccessor,
      GuardAccessor,
      std::unique_ptr<GetGenericDictGuardAccessor>>(
      py_m, "GetGenericDictGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      GetItemGuardAccessor,
      GuardAccessor,
      std::unique_ptr<GetItemGuardAccessor>>(py_m, "GetItemGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      DictGetItemGuardAccessor,
      GuardAccessor,
      std::unique_ptr<DictGetItemGuardAccessor>>(
      py_m, "DictGetItemGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      ListGetItemGuardAccessor,
      GuardAccessor,
      std::unique_ptr<ListGetItemGuardAccessor>>(
      py_m, "ListGetItemGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      TupleGetItemGuardAccessor,
      GuardAccessor,
      std::unique_ptr<TupleGetItemGuardAccessor>>(
      py_m, "TupleGetItemGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      FuncDefaultsGuardAccessor,
      GuardAccessor,
      std::unique_ptr<FuncDefaultsGuardAccessor>>(
      py_m, "FuncDefaultsGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      FuncKwDefaultsGuardAccessor,
      GuardAccessor,
      std::unique_ptr<FuncKwDefaultsGuardAccessor>>(
      py_m, "FuncKwDefaultsGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      GlobalsGuardAccessor,
      GuardAccessor,
      std::unique_ptr<GlobalsGuardAccessor>>(py_m, "GlobalsGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      TypeGuardAccessor,
      GuardAccessor,
      std::unique_ptr<TypeGuardAccessor>>(py_m, "TypeGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      WeakRefCallGuardAccessor,
      GuardAccessor,
      std::unique_ptr<WeakRefCallGuardAccessor>>(
      py_m, "WeakRefCallGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      CallFunctionNoArgsGuardAccessor,
      GuardAccessor,
      std::unique_ptr<CallFunctionNoArgsGuardAccessor>>(
      py_m, "CallFunctionNoArgsGuardAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      TupleIteratorGetItemAccessor,
      GuardAccessor,
      std::unique_ptr<TupleIteratorGetItemAccessor>>(
      py_m, "TupleIteratorGetItemAccessor");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<
      GlobalWeakRefGuardAccessor,
      GuardAccessor,
      std::unique_ptr<GlobalWeakRefGuardAccessor>>(
      py_m, "GlobalWeakRefGuardAccessor");

  // Guard Manager - No constructor in python, python should use
  // RootGuardManager.
  py::class_<GuardManager, std::unique_ptr<GuardManager>>(py_m, "GuardManager")
      // return by reference because GuardManager has the ownership of accessors
      .def("get_source", &GuardManager::get_source)
      .def("fail_count", &GuardManager::fail_count)
      .def(
          "get_accessors",
          &GuardManager::get_accessors,
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of child
      // managers
      .def(
          "get_child_managers",
          &GuardManager::get_child_managers,
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of leaf
      // guards
      .def(
          "get_leaf_guards",
          &GuardManager::get_leaf_guards,
          py::return_value_policy::reference)
      .def(
          "add_lambda_guard",
          [](GuardManager& self,
             py::object lambda,
             py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<LAMBDA_GUARD>(
                std::move(lambda), std::move(verbose_code_parts)));
          })
      .def(
          "add_type_match_guard",
          [](GuardManager& self,
             py::object value,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("TYPE_MATCH");
            self.add_leaf_guard(std::make_shared<TYPE_MATCH>(
                std::move(value), std::move(verbose_code_parts)));
          })
      .def(
          "add_id_match_guard",
          [](GuardManager& self,
             py::object value,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("ID_MATCH");
            self.add_leaf_guard(std::make_shared<ID_MATCH>(
                std::move(value), std::move(verbose_code_parts)));
          })
      .def(
          "add_equals_match_guard",
          [](GuardManager& self,
             py::object value,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("EQUALS_MATCH");
            self.add_leaf_guard(std::make_shared<EQUALS_MATCH>(
                std::move(value), std::move(verbose_code_parts)));
          })
      .def(
          "add_length_check_guard",
          [](GuardManager& self,
             py::object value,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("LENGTH_CHECK");
            self.add_leaf_guard(std::make_shared<LENGTH_CHECK>(
                std::move(value), std::move(verbose_code_parts)));
          })
      .def(
          "add_dict_length_check_guard",
          [](GuardManager& self,
             py::object value,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("DICT_LENGTH");
            self.add_leaf_guard(std::make_shared<DICT_LENGTH>(
                std::move(value), std::move(verbose_code_parts)));
          })
      .def(
          "add_tuple_iterator_length_guard",
          [](GuardManager& self,
             py::object length,
             py::object type_id,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("TUPLE_ITERATOR_LEN");
            self.add_leaf_guard(std::make_shared<TUPLE_ITERATOR_LEN>(
                std::move(length),
                std::move(type_id),
                std::move(verbose_code_parts)));
          })
      .def(
          "add_range_iterator_match_guard",
          [](GuardManager& self,
             py::object start,
             py::object stop,
             py::object step,
             py::object type_id,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("RANGE_ITERATOR_MATCH");
            self.add_leaf_guard(std::make_shared<RANGE_ITERATOR_MATCH>(
                std::move(start),
                std::move(stop),
                std::move(step),
                std::move(type_id),
                std::move(verbose_code_parts)));
          })
      .def(
          "add_default_device_guard",
          [](GuardManager& self, py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<DEFAULT_DEVICE>(
                std::move(verbose_code_parts)));
          })
      .def(
          "add_not_none_guard",
          [](GuardManager& self, py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("NOT_NONE");
            self.add_leaf_guard(
                std::make_shared<NOT_NONE>(std::move(verbose_code_parts)));
          })
      .def(
          "add_global_state_guard",
          [](GuardManager& self, py::object verbose_code_parts) -> void {
            self.add_leaf_guard(
                std::make_shared<GLOBAL_STATE>(std::move(verbose_code_parts)));
          })
      .def(
          "add_torch_function_mode_stack_guard",
          [](GuardManager& self,
             const py::list& initial_stack,
             py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<TORCH_FUNCTION_MODE_STACK>(
                initial_stack, std::move(verbose_code_parts)));
          })
      .def(
          "add_data_ptr_guard",
          [](GuardManager& self,
             py::object data_ptr,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("DATA_PTR_MATCH");
            self.add_leaf_guard(std::make_shared<DATA_PTR_MATCH>(
                std::move(data_ptr), std::move(verbose_code_parts)));
          })
      .def(
          "add_no_hasattr_guard",
          [](GuardManager& self,
             py::object attr_name,
             py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<NO_HASATTR>(
                std::move(attr_name), std::move(verbose_code_parts)));
          })
      .def(
          "add_dict_contains_guard",
          [](GuardManager& self,
             bool contains,
             py::object key,
             py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<DICT_CONTAINS>(
                contains, std::move(key), std::move(verbose_code_parts)));
          })
      .def(
          "add_dynamic_indices_guard",
          [](GuardManager& self,
             py::set value,
             py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<DYNAMIC_INDICES>(
                std::move(value), std::move(verbose_code_parts)));
          })
      .def(
          "add_dict_version_guard",
          [](GuardManager& self,
             py::object value,
             py::object verbose_code_parts) -> void {
            self.add_leaf_guard(std::make_shared<DICT_VERSION>(
                std::move(value), std::move(verbose_code_parts)));
          })
      .def(
          "add_tensor_match_guard",
          [](GuardManager& self,
             py::object value,
             py::object sizes,
             py::object strides,
             py::object tensor_name,
             py::object verbose_code_parts) -> void {
            SKIP_IF_GUARD_ALREADY_PRESENT("TENSOR_MATCH");
            self.add_leaf_guard(std::make_shared<TENSOR_MATCH>(
                self.get_root(),
                std::move(value),
                std::move(sizes),
                std::move(strides),
                std::move(tensor_name),
                std::move(verbose_code_parts)));
          })

      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "getitem_manager",
          &GuardManager::get_child_manager<GetItemGuardAccessor>,
          py::arg("key"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "dict_getitem_manager",
          &GuardManager::get_child_manager<DictGetItemGuardAccessor>,
          py::arg("key"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "list_getitem_manager",
          &GuardManager::get_child_manager<ListGetItemGuardAccessor>,
          py::arg("key"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "tuple_getitem_manager",
          &GuardManager::get_child_manager<TupleGetItemGuardAccessor>,
          py::arg("key"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "func_defaults_manager",
          [](GuardManager& self,
             std::string source,
             py::object example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            // A unique key is used to save as the accessor key.
            py::str unique_key("__defaults_accessor__");
            return self.get_child_manager<FuncDefaultsGuardAccessor>(
                std::move(unique_key),
                std::move(source),
                std::move(example_value),
                guard_manager_enum);
          },
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "func_kwdefaults_manager",
          [](GuardManager& self,
             std::string source,
             py::object example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            // A unique key is used to save as the accessor key.
            py::str unique_key("__kwdefaults_accessor__");
            return self.get_child_manager<FuncKwDefaultsGuardAccessor>(
                std::move(unique_key),
                std::move(source),
                std::move(example_value),
                guard_manager_enum);
          },
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "globals_dict_manager",
          &GuardManager::get_child_manager<GlobalsGuardAccessor>,
          py::arg("f_globals"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "type_manager",
          [](GuardManager& self,
             std::string source,
             py::handle example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            // A unique key is used to save as the accessor key.
            py::str unique_key("__type_accessor__");
            return self.get_child_manager<TypeGuardAccessor>(
                std::move(unique_key),
                std::move(source),
                example_value,
                guard_manager_enum);
          },
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "weakref_call_manager",
          [](GuardManager& self,
             std::string source,
             py::handle example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            // A unique key is used to save as the accessor key.
            py::str unique_key("__weakref_call_accessor__");
            return self.get_child_manager<WeakRefCallGuardAccessor>(
                std::move(unique_key),
                std::move(source),
                example_value,
                guard_manager_enum);
          },
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "call_function_no_args_manager",
          [](GuardManager& self,
             std::string source,
             py::handle example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            // A unique key is used to save as the accessor key.
            py::str unique_key("__call_function_no_args_accessor__");
            return self.get_child_manager<CallFunctionNoArgsGuardAccessor>(
                std::move(unique_key),
                std::move(source),
                example_value,
                guard_manager_enum);
          },
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "tuple_iterator_getitem_manager",
          &GuardManager::get_child_manager<TupleIteratorGetItemAccessor>,
          py::arg("index"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "global_weakref_manager",
          &GuardManager::get_child_manager<GlobalWeakRefGuardAccessor>,
          py::arg("global_name"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "lambda_manager",
          &GuardManager::get_child_manager<PythonLambdaGuardAccessor>,
          py::arg("python_lambda"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "grad_manager",
          [](GuardManager& self,
             std::string source,
             py::handle example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            // A unique key is used to save as the accessor key.
            py::str unique_key("__grad_accessor__");
            return self.get_child_manager<GradGuardAccessor>(
                std::move(unique_key),
                std::move(source),
                example_value,
                guard_manager_enum);
          },
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "get_generic_dict_manager",
          [](GuardManager& self,
             std::string source,
             py::handle example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            // A unique key is used to save as the accessor key.
            py::str unique_key("__generic_dict_accessor__");
            return self.get_child_manager<GetGenericDictGuardAccessor>(
                std::move(unique_key),
                std::move(source),
                example_value,
                guard_manager_enum);
          },
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because C++ GuardManager has the ownership of
      // accessors and guard managers
      .def(
          "getattr_manager",
          &GuardManager::get_child_manager<GetAttrGuardAccessor>,
          py::arg("attr"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference);

  // Root Guard Manager
  py::class_<RootGuardManager, GuardManager, std::unique_ptr<RootGuardManager>>(
      py_m, "RootGuardManager")
      .def(py::init<>())
      .def("check", &RootGuardManager::check)
      .def("check_verbose", &RootGuardManager::check_verbose)
      .def(
          "clone_manager",
          &RootGuardManager::clone_manager,
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of leaf
      // guards
      .def(
          "get_epilogue_lambda_guards",
          &RootGuardManager::get_epilogue_lambda_guards,
          py::return_value_policy::reference)
      .def(
          "add_epilogue_lambda_guard",
          [](RootGuardManager& self,
             py::object lambda,
             py::object verbose_code_parts) -> void {
            self.add_epilogue_lambda_guard(std::make_unique<LAMBDA_GUARD>(
                std::move(lambda), std::move(verbose_code_parts)));
          });

  // Dict Guard Manager
  py::class_<DictGuardManager, GuardManager, std::unique_ptr<DictGuardManager>>(
      py_m, "DictGuardManager")
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "get_key_manager",
          [](DictGuardManager& self,
             py::object index,
             std::string source,
             py::handle example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            return self.get_key_manager(
                std::move(index),
                std::move(source),
                example_value,
                guard_manager_enum);
          },
          py::arg("index"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "get_value_manager",
          [](DictGuardManager& self,
             py::object index,
             std::string source,
             py::handle example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            return self.get_value_manager(
                std::move(index),
                std::move(source),
                example_value,
                guard_manager_enum);
          },
          py::arg("index"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference)
      // return by reference because GuardManager has the ownership of leaf
      // guards
      .def(
          "get_key_value_managers",
          &DictGuardManager::get_key_value_managers,
          py::return_value_policy::reference)
      // Skipped leaf guards
      .def("add_type_match_guard", &DictGuardManager::skip_adding_guard)
      .def("add_dict_length_check_guard", &DictGuardManager::skip_adding_guard)
      // Permitted leaf guards
      .def(
          "add_dict_contains_guard",
          [](DictGuardManager& self,
             bool contains,
             py::object key,
             py::object verbose_code_parts) -> void {
            self.add_permitted_leaf_guard(std::make_shared<DICT_CONTAINS>(
                contains, std::move(key), std::move(verbose_code_parts)));
          })
      .def(
          "add_dict_version_guard",
          [](DictGuardManager& self,
             py::object value,
             py::object verbose_code_parts) -> void {
            // DICT_VERSION is used in a very narrow context today to guard on
            // pytree SUPPORTED_NODES. We can remove this once we have tags in
            // DictGuardManager.
            self.add_permitted_leaf_guard(std::make_shared<DICT_VERSION>(
                std::move(value), std::move(verbose_code_parts)));
          })
      // Not permitted accessors (each name is registered exactly once)
      .def("lambda_manager", &DictGuardManager::fail_on_get_child_manager)
      .def("getitem_manager", &DictGuardManager::fail_on_get_child_manager)
      .def("dict_getitem_manager", &DictGuardManager::fail_on_get_child_manager)
      .def("globals_dict_manager", &DictGuardManager::fail_on_get_child_manager)
      .def(
          "tuple_iterator_getitem_manager",
          &DictGuardManager::fail_on_get_child_manager)
      .def(
          "global_weakref_manager",
          &DictGuardManager::fail_on_get_child_manager)
      // Permitted accessors (and also type_manager)
      // return by reference because GuardManager has the ownership of accessors
      // and guard managers
      .def(
          "getattr_manager",
          [](DictGuardManager& self,
             py::object attr_name,
             std::string source,
             py::handle example_value,
             py::handle guard_manager_enum) -> GuardManager* {
            if (self.is_exact_dict_type()) {
              throw std::runtime_error(
                  "getattr_manager on a DictGuardManager is supported only for dict subclasses");
            }
            return self.get_child_manager<GetAttrGuardAccessor>(
                std::move(attr_name),
                std::move(source),
                example_value,
                guard_manager_enum);
          },
          py::arg("attr"),
          py::arg("source"),
          py::arg("example_value"),
          py::arg("guard_manager_enum"),
          py::return_value_policy::reference);

  // Dict Subclass Guard Manager
  py::class_< // NOLINT
      DictSubclassGuardManager,
      DictGuardManager,
      std::unique_ptr<DictSubclassGuardManager>>(
      py_m, "DictSubclassGuardManager") // NOLINT
      .def(
          "add_no_hasattr_guard",
          [](DictSubclassGuardManager& self,
             py::object attr_name,
             py::object verbose_code_parts) -> void {
            self.add_permitted_leaf_guard(std::make_shared<NO_HASATTR>(
                std::move(attr_name), std::move(verbose_code_parts)));
          });

  // Module-level helper functions
  py_m.def("install_object_aliasing_guard", install_object_aliasing_guard);
  py_m.def(
      "install_no_tensor_aliasing_guard", install_no_tensor_aliasing_guard);
  py_m.def(
      "install_storage_overlapping_guard", install_storage_overlapping_guard);
  py_m.def(
      "compute_overlapping_tensors",
      [](const std::vector<Tensor> tensors, bool symbolic) {
        // Pick the correct Meta class, depending on whether we are
        // dealing with symbolic values or not.
        if (symbolic) {
          return compute_overlapping_tensors<DynamicMeta>(tensors);
        } else {
          return compute_overlapping_tensors<StaticMeta>(tensors);
        }
      },
      py::arg("tensors"),
      py::arg("symbolic") = true);
  py_m.def("profile_guard_manager", profile_guard_manager);

// initialize dict_version_map watcher for 3.12
#if IS_PYTHON_3_12_PLUS

  dict_version_watcher_id = PyDict_AddWatcher(dict_version_watch_callback);
  if (dict_version_watcher_id == -1) {
    throw std::runtime_error("Failed to install dict_version_watch_callback");
  }

#endif

  return m;
}
} // namespace torch::dynamo
|